@@ -448,10 +448,20 @@ def test_in_flight_auto_batcher_restore_order(
     # batcher.restore_original_order([si_sim_state])


+@pytest.mark.parametrize(
+    "num_steps_per_batch",
+    [
+        5,  # At 5 steps, not every state will converge before the next batch.
+        # This tests the merging of partially converged states with new states,
+        # which has been a bug in the past. See https://github.com/Radical-AI/torch-sim/pull/219
+        10,  # At 10 steps, all states will converge before the next batch.
+    ],
+)
 def test_in_flight_with_fire(
     si_sim_state: ts.SimState,
     fe_supercell_sim_state: ts.SimState,
     lj_model: LennardJonesModel,
+    num_steps_per_batch: int,
 ) -> None:
     fire_init, fire_update = unit_cell_fire(lj_model)

@@ -489,62 +499,7 @@ def convergence_fn(state: ts.SimState) -> bool:
         if state is None:
             break

-        # run 10 steps (so all states converge before the next batch)
-        for _ in range(10):
-            state = fire_update(state)
-            convergence_tensor = convergence_fn(state)
-
-    assert len(all_completed_states) == len(fire_states)
-
-
-def test_in_flight_with_fire_only_converge_some_states(
-    si_sim_state: ts.SimState,
-    fe_supercell_sim_state: ts.SimState,
-    lj_model: LennardJonesModel,
-) -> None:
-    """This test is the same as the test_in_flight_with_fire above
-    but we only converge a few states before we trigger the auto batcher.
-    This can catch bugs related to merging partially converged and fully converged
-    states. See https://github.com/Radical-AI/torch-sim/pull/219
-    """
-    fire_init, fire_update = unit_cell_fire(lj_model)
-
-    si_fire_state = fire_init(si_sim_state)
-    fe_fire_state = fire_init(fe_supercell_sim_state)
-
-    fire_states = [si_fire_state, fe_fire_state] * 5
-    fire_states = [state.clone() for state in fire_states]
-    for state in fire_states:
-        state.positions += torch.randn_like(state.positions) * 0.01
-
-    batcher = InFlightAutoBatcher(
-        model=lj_model,
-        memory_scales_with="n_atoms",
-        # max_metric=400_000,
-        max_memory_scaler=600,
-    )
-    batcher.load_states(fire_states)
-
-    def convergence_fn(state: ts.SimState) -> bool:
-        system_wise_max_force = torch.zeros(
-            state.n_systems, device=state.device, dtype=torch.float64
-        )
-        max_forces = state.forces.norm(dim=1)
-        system_wise_max_force = system_wise_max_force.scatter_reduce(
-            dim=0, index=state.system_idx, src=max_forces, reduce="amax"
-        )
-        return system_wise_max_force < 5e-1
-
-    all_completed_states, convergence_tensor = [], None
-    while True:
-        state, completed_states = batcher.next_batch(state, convergence_tensor)
-
-        all_completed_states.extend(completed_states)
-        if state is None:
-            break
-
-        # run 5 steps (so not every state can converge before the next batch)
-        for _ in range(5):
+        for _ in range(num_steps_per_batch):
             state = fire_update(state)
             convergence_tensor = convergence_fn(state)

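A note on the convergence check this loop relies on: `convergence_fn` reduces per-atom force norms to one maximum per system with `scatter_reduce`, then thresholds at 5e-1 to produce the per-system convergence tensor. A minimal standalone sketch of that reduction, using made-up shapes (7 atoms across 2 systems) in place of a real ts.SimState:

import torch

# Hypothetical stand-in for state.forces / state.system_idx: 7 atoms in 2 systems.
forces = torch.randn(7, 3, dtype=torch.float64)   # per-atom force vectors
system_idx = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # atom -> system mapping

max_forces = forces.norm(dim=1)  # per-atom force magnitudes, shape (7,)
# "amax" keeps the largest src value scattered into each index slot,
# giving the maximum force norm per system.
system_wise_max_force = torch.zeros(2, dtype=torch.float64).scatter_reduce(
    dim=0, index=system_idx, src=max_forces, reduce="amax"
)
convergence_tensor = system_wise_max_force < 5e-1  # one bool per system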
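And for readers skimming the hunks, here is the drive loop the surviving test runs after this change, assembled in one place. `batcher`, `fire_update`, `convergence_fn`, and the new `num_steps_per_batch` parameter are all defined in the test; initializing `state = None` before the first `next_batch` call is an assumption of this sketch, not a documented API contract:

all_completed_states, convergence_tensor, state = [], None, None
while True:
    # Swap converged systems out of the batch and load queued states in.
    state, completed_states = batcher.next_batch(state, convergence_tensor)
    all_completed_states.extend(completed_states)
    if state is None:  # every loaded state has converged
        break

    # With num_steps_per_batch=5, some systems remain unconverged here,
    # exercising the partial-convergence merge path fixed in PR #219.
    for _ in range(num_steps_per_batch):
        state = fire_update(state)
        convergence_tensor = convergence_fn(state)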