@@ -152,8 +152,9 @@ def test_fire_optimization(
152
152
energies .append (state .energy .item ())
153
153
steps_taken += 1
154
154
155
- if steps_taken == max_steps :
156
- print (f"FIRE optimization for { md_flavor = } did not converge in { max_steps } steps" )
155
+ assert steps_taken < max_steps , (
156
+ f"FIRE optimization for { md_flavor = } did not converge in { max_steps = } "
157
+ )
157
158
158
159
energies = energies [1 :]
159
160
@@ -327,7 +328,6 @@ def test_unit_cell_fire_optimization(
327
328
ar_supercell_sim_state : ts .SimState , lj_model : torch .nn .Module , md_flavor : MdFlavor
328
329
) -> None :
329
330
"""Test that the Unit Cell FIRE optimizer actually minimizes energy."""
330
- print (f"\n --- Starting test_unit_cell_fire_optimization for { md_flavor = } ---" )
331
331
332
332
# Add random displacement to positions and cell
333
333
current_positions = (
@@ -347,48 +347,33 @@ def test_unit_cell_fire_optimization(
347
347
atomic_numbers = ar_supercell_sim_state .atomic_numbers .clone (),
348
348
batch = ar_supercell_sim_state .batch .clone (),
349
349
)
350
- print (f"[{ md_flavor } ] Initial SimState created." )
351
350
352
351
initial_state_positions = current_sim_state .positions .clone ()
353
352
initial_state_cell = current_sim_state .cell .clone ()
354
353
355
354
# Initialize FIRE optimizer
356
- print (f"Initializing { md_flavor } optimizer..." )
357
355
init_fn , update_fn = unit_cell_fire (
358
356
model = lj_model ,
359
357
dt_max = 0.3 ,
360
358
dt_start = 0.1 ,
361
359
md_flavor = md_flavor ,
362
360
)
363
- print (f"[{ md_flavor } ] Optimizer functions obtained." )
364
361
365
362
state = init_fn (current_sim_state )
366
- energy = float (getattr (state , "energy" , "nan" ))
367
- print (f"[{ md_flavor } ] Initial state created by init_fn. { energy = :.4f} " )
368
363
369
364
# Run optimization for a few steps
370
365
energies = [1000.0 , state .energy .item ()]
371
366
max_steps = 1000
372
367
steps_taken = 0
373
- print (f"[{ md_flavor } ] Entering optimization loop (max_steps: { max_steps } )..." )
374
368
375
369
while abs (energies [- 2 ] - energies [- 1 ]) > 1e-6 and steps_taken < max_steps :
376
370
state = update_fn (state )
377
371
energies .append (state .energy .item ())
378
372
steps_taken += 1
379
373
380
- print (f"[{ md_flavor } ] Loop finished after { steps_taken } steps." )
381
-
382
- if steps_taken == max_steps and abs (energies [- 2 ] - energies [- 1 ]) > 1e-6 :
383
- print (
384
- f"WARNING: Unit Cell FIRE { md_flavor = } optimization did not converge "
385
- f"in { max_steps } steps. Final energy: { energies [- 1 ]:.4f} "
386
- )
387
- else :
388
- print (
389
- f"Unit Cell FIRE { md_flavor = } optimization converged in { steps_taken } "
390
- f"steps. Final energy: { energies [- 1 ]:.4f} "
391
- )
374
+ assert steps_taken < max_steps , (
375
+ f"Unit Cell FIRE { md_flavor = } optimization did not converge in { max_steps = } "
376
+ )
392
377
393
378
energies = energies [1 :]
394
379
@@ -522,7 +507,6 @@ def test_frechet_cell_fire_optimization(
522
507
) -> None :
523
508
"""Test that the Frechet Cell FIRE optimizer actually minimizes energy for different
524
509
md_flavors."""
525
- print (f"\n --- Starting test_frechet_cell_fire_optimization for { md_flavor = } ---" )
526
510
527
511
# Add random displacement to positions and cell
528
512
# Create a fresh copy for each test run to avoid interference
@@ -543,48 +527,33 @@ def test_frechet_cell_fire_optimization(
543
527
atomic_numbers = ar_supercell_sim_state .atomic_numbers .clone (),
544
528
batch = ar_supercell_sim_state .batch .clone (),
545
529
)
546
- print (f"[{ md_flavor } ] Initial SimState created for Frechet test." )
547
530
548
531
initial_state_positions = current_sim_state .positions .clone ()
549
532
initial_state_cell = current_sim_state .cell .clone ()
550
533
551
534
# Initialize FIRE optimizer
552
- print (f"Initializing Frechet { md_flavor } optimizer..." )
553
535
init_fn , update_fn = frechet_cell_fire (
554
536
model = lj_model ,
555
537
dt_max = 0.3 ,
556
538
dt_start = 0.1 ,
557
539
md_flavor = md_flavor ,
558
540
)
559
- print (f"[{ md_flavor } ] Frechet optimizer functions obtained." )
560
541
561
542
state = init_fn (current_sim_state )
562
- energy = float (getattr (state , "energy" , "nan" ))
563
- print (f"[{ md_flavor } ] Initial state created by Frechet init_fn. { energy = :.4f} " )
564
543
565
544
# Run optimization for a few steps
566
545
energies = [1000.0 , state .energy .item ()] # Ensure float for comparison
567
546
max_steps = 1000
568
547
steps_taken = 0
569
- print (f"[{ md_flavor } ] Entering Frechet optimization loop (max_steps: { max_steps } )..." )
570
548
571
549
while abs (energies [- 2 ] - energies [- 1 ]) > 1e-6 and steps_taken < max_steps :
572
550
state = update_fn (state )
573
551
energies .append (state .energy .item ())
574
552
steps_taken += 1
575
553
576
- print (f"[{ md_flavor } ] Frechet loop finished after { steps_taken } steps." )
577
-
578
- if steps_taken == max_steps and abs (energies [- 2 ] - energies [- 1 ]) > 1e-6 :
579
- print (
580
- f"WARNING: Frechet Cell FIRE { md_flavor = } optimization did not converge "
581
- f"in { max_steps } steps. Final energy: { energies [- 1 ]:.4f} "
582
- )
583
- else :
584
- print (
585
- f"Frechet Cell FIRE { md_flavor = } optimization converged in { steps_taken } "
586
- f"steps. Final energy: { energies [- 1 ]:.4f} "
587
- )
554
+ assert steps_taken < max_steps , (
555
+ f"Frechet FIRE { md_flavor = } optimization did not converge in { max_steps = } "
556
+ )
588
557
589
558
energies = energies [1 :]
590
559
@@ -600,8 +569,7 @@ def test_frechet_cell_fire_optimization(
600
569
pressure = torch .trace (state .stress .squeeze (0 )) / 3.0
601
570
602
571
# Adjust tolerances if needed, Frechet might behave slightly differently
603
- pressure_tol = 0.01
604
- force_tol = 0.2
572
+ pressure_tol , force_tol = 0.01 , 0.2
605
573
606
574
assert torch .abs (pressure ) < pressure_tol , (
607
575
f"{ md_flavor = } pressure should be below { pressure_tol = } after Frechet "
0 commit comments