Skip to content

Commit 84d6750

Browse files
committed
parameterize num_steps_per_batch for test_in_flight_with_fire
1 parent aad9be5 commit 84d6750

File tree

1 file changed

+11
-56
lines changed

1 file changed

+11
-56
lines changed

tests/test_autobatching.py

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order(
448448
# batcher.restore_original_order([si_sim_state])
449449

450450

451+
@pytest.mark.parametrize(
452+
"num_steps_per_batch",
453+
[
454+
5, # At 5 steps, not every state will converge before the next batch.
455+
# This tests the merging of partially converged states with new states
456+
# which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219
457+
10, # At 10 steps, all states will converge before the next batch
458+
],
459+
)
451460
def test_in_flight_with_fire(
452461
si_sim_state: ts.SimState,
453462
fe_supercell_sim_state: ts.SimState,
454463
lj_model: LennardJonesModel,
464+
num_steps_per_batch: int,
455465
) -> None:
456466
fire_init, fire_update = unit_cell_fire(lj_model)
457467

@@ -489,62 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool:
489499
if state is None:
490500
break
491501

492-
# run 10 steps (so all states converge before the next batch)
493-
for _ in range(10):
494-
state = fire_update(state)
495-
convergence_tensor = convergence_fn(state)
496-
497-
assert len(all_completed_states) == len(fire_states)
498-
499-
500-
def test_in_flight_with_fire_only_converge_some_states(
501-
si_sim_state: ts.SimState,
502-
fe_supercell_sim_state: ts.SimState,
503-
lj_model: LennardJonesModel,
504-
) -> None:
505-
"""This test is the same as the test_in_flight_with_fire above
506-
but we only converge a few states before we trigger the auto batcher.
507-
This can catch bugs related to merging partially converged and fully converged
508-
states. See https://github.com/Radical-AI/torch-sim/pull/219
509-
"""
510-
fire_init, fire_update = unit_cell_fire(lj_model)
511-
512-
si_fire_state = fire_init(si_sim_state)
513-
fe_fire_state = fire_init(fe_supercell_sim_state)
514-
515-
fire_states = [si_fire_state, fe_fire_state] * 5
516-
fire_states = [state.clone() for state in fire_states]
517-
for state in fire_states:
518-
state.positions += torch.randn_like(state.positions) * 0.01
519-
520-
batcher = InFlightAutoBatcher(
521-
model=lj_model,
522-
memory_scales_with="n_atoms",
523-
# max_metric=400_000,
524-
max_memory_scaler=600,
525-
)
526-
batcher.load_states(fire_states)
527-
528-
def convergence_fn(state: ts.SimState) -> bool:
529-
system_wise_max_force = torch.zeros(
530-
state.n_systems, device=state.device, dtype=torch.float64
531-
)
532-
max_forces = state.forces.norm(dim=1)
533-
system_wise_max_force = system_wise_max_force.scatter_reduce(
534-
dim=0, index=state.system_idx, src=max_forces, reduce="amax"
535-
)
536-
return system_wise_max_force < 5e-1
537-
538-
all_completed_states, convergence_tensor = [], None
539-
while True:
540-
state, completed_states = batcher.next_batch(state, convergence_tensor)
541-
542-
all_completed_states.extend(completed_states)
543-
if state is None:
544-
break
545-
546-
# run 5 steps (so not every state can converge before the next batch)
547-
for _ in range(5):
502+
for _ in range(num_steps_per_batch):
548503
state = fire_update(state)
549504
convergence_tensor = convergence_fn(state)
550505

0 commit comments

Comments
 (0)