Commit bdcfd1a

remove debug print statements from tests and replace them with assertions, adjust ruff config to flag prints in package code
1 parent db4782d

11 files changed: +24, -60 lines

.pre-commit-config.yaml
Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.12
+    rev: v0.11.13
     hooks:
       - id: ruff
         args: [--fix]

pyproject.toml
Lines changed: 1 addition & 3 deletions

@@ -114,10 +114,8 @@ ignore = [
     "S301", # pickle and modules that wrap it can be unsafe, possible security issue
     "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
     "SIM105", # Use contextlib.suppress instead of try-except-pass
-    "T201", # print found
     "TD", # flake8-todos
     "TRY003", # Avoid specifying long messages outside the exception class
-    "TRY301", # Abstract raise to an inner function
 ]
 pydocstyle.convention = "google"
 isort.split-on-trailing-comma = false
@@ -126,7 +124,7 @@ pep8-naming.ignore-names = ["get_kT", "kT"]
 
 [tool.ruff.lint.per-file-ignores]
 "**/tests/*" = ["ANN201", "D", "S101"]
-"examples/**/*" = ["B018"]
+"examples/**/*" = ["B018", "T201"]
 "examples/tutorials/**/*" = ["ALL"]
 
 [tool.ruff.format]
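
With "T201" dropped from the global ignore list, Ruff now flags bare print calls throughout the package, while the new per-file-ignores entry keeps them legal under examples/. A minimal sketch of what the rule reports (the function name is hypothetical, not from the repo):

    def log_progress(step: int) -> None:
        print(f"step {step}")  # ruff check reports: T201 print found

The package-code diffs below either delete such prints or opt out explicitly with noqa comments, as sketched after the last diff.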

tests/models/test_morse.py
Lines changed: 0 additions & 3 deletions

@@ -28,7 +28,6 @@ def test_morse_pair_asymptotic() -> None:
     dr = torch.tensor([[1.0]])  # Large distance
     epsilon = 5.0
     energy = morse_pair(dr, epsilon=epsilon)
-    print(energy, -epsilon * torch.ones_like(energy))
     torch.testing.assert_close(
         energy, -epsilon * torch.ones_like(energy), rtol=1e-2, atol=1e-5
     )
@@ -55,8 +54,6 @@ def test_morse_force_energy_consistency() -> None:
     force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0]
 
     # Compare forces
-    print(force_direct)
-    print(force_from_grad)
     assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
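
The change above is the commit's core pattern: a print that merely displayed two tensors becomes a tolerance-aware assertion that fails loudly. A self-contained sketch of the same idiom, assuming only torch:

    import torch

    def check_close() -> None:
        expected = -5.0 * torch.ones(1, 1)
        actual = expected + 1e-6  # small numerical noise, like the asymptotic tail
        # instead of print(actual, expected), compare with explicit tolerances
        torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-5)

    check_close()

On failure, torch.testing.assert_close raises an AssertionError that includes the mismatched elements and the largest absolute and relative differences, so nothing is lost by removing the prints.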

tests/test_autobatching.py
Lines changed: 0 additions & 3 deletions

@@ -457,10 +457,7 @@ def convergence_fn(state: ts.SimState) -> bool:
 
     all_completed_states, convergence_tensor = [], None
     while True:
-        print(f"Starting new batch of {state.n_batches} states.")
-
         state, completed_states = batcher.next_batch(state, convergence_tensor)
-        print("Number of completed states", len(completed_states))
 
        all_completed_states.extend(completed_states)
        if state is None:
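
Here the prints are removed outright because the loop's invariant can be checked once it finishes. A runnable sketch of the consume-until-exhausted pattern with a mock batcher (MockBatcher and its next_batch signature are illustrative stand-ins, not the repo's autobatcher):

    from collections import deque

    class MockBatcher:
        """Toy stand-in: hands out fixed-size batches until its queue empties."""

        def __init__(self, items: list[int], batch_size: int = 2) -> None:
            self.queue = deque(items)
            self.batch_size = batch_size

        def next_batch(self, state, convergence_tensor):
            completed = list(state or [])  # treat the previous batch as completed
            if not self.queue:
                return None, completed
            n = min(self.batch_size, len(self.queue))
            return [self.queue.popleft() for _ in range(n)], completed

    batcher = MockBatcher(list(range(5)))
    all_completed_states, state, convergence_tensor = [], None, None
    while True:
        state, completed_states = batcher.next_batch(state, convergence_tensor)
        all_completed_states.extend(completed_states)
        if state is None:
            break

    assert len(all_completed_states) == 5  # verify the invariant instead of printing counts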

tests/test_optimizers.py
Lines changed: 10 additions & 42 deletions

@@ -152,8 +152,9 @@ def test_fire_optimization(
         energies.append(state.energy.item())
         steps_taken += 1
 
-    if steps_taken == max_steps:
-        print(f"FIRE optimization for {md_flavor=} did not converge in {max_steps} steps")
+    assert steps_taken < max_steps, (
+        f"FIRE optimization for {md_flavor=} did not converge in {max_steps=}"
+    )
 
     energies = energies[1:]
 
@@ -327,7 +328,6 @@ def test_unit_cell_fire_optimization(
     ar_supercell_sim_state: ts.SimState, lj_model: torch.nn.Module, md_flavor: MdFlavor
 ) -> None:
     """Test that the Unit Cell FIRE optimizer actually minimizes energy."""
-    print(f"\n--- Starting test_unit_cell_fire_optimization for {md_flavor=} ---")
 
     # Add random displacement to positions and cell
     current_positions = (
@@ -347,48 +347,33 @@ def test_unit_cell_fire_optimization(
         atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(),
         batch=ar_supercell_sim_state.batch.clone(),
     )
-    print(f"[{md_flavor}] Initial SimState created.")
 
     initial_state_positions = current_sim_state.positions.clone()
     initial_state_cell = current_sim_state.cell.clone()
 
     # Initialize FIRE optimizer
-    print(f"Initializing {md_flavor} optimizer...")
     init_fn, update_fn = unit_cell_fire(
         model=lj_model,
         dt_max=0.3,
         dt_start=0.1,
         md_flavor=md_flavor,
     )
-    print(f"[{md_flavor}] Optimizer functions obtained.")
 
     state = init_fn(current_sim_state)
-    energy = float(getattr(state, "energy", "nan"))
-    print(f"[{md_flavor}] Initial state created by init_fn. {energy=:.4f}")
 
     # Run optimization for a few steps
     energies = [1000.0, state.energy.item()]
     max_steps = 1000
     steps_taken = 0
-    print(f"[{md_flavor}] Entering optimization loop (max_steps: {max_steps})...")
 
     while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps:
         state = update_fn(state)
         energies.append(state.energy.item())
         steps_taken += 1
 
-    print(f"[{md_flavor}] Loop finished after {steps_taken} steps.")
-
-    if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6:
-        print(
-            f"WARNING: Unit Cell FIRE {md_flavor=} optimization did not converge "
-            f"in {max_steps} steps. Final energy: {energies[-1]:.4f}"
-        )
-    else:
-        print(
-            f"Unit Cell FIRE {md_flavor=} optimization converged in {steps_taken} "
-            f"steps. Final energy: {energies[-1]:.4f}"
-        )
+    assert steps_taken < max_steps, (
+        f"Unit Cell FIRE {md_flavor=} optimization did not converge in {max_steps=}"
+    )
 
     energies = energies[1:]
 
@@ -522,7 +507,6 @@ def test_frechet_cell_fire_optimization(
 ) -> None:
     """Test that the Frechet Cell FIRE optimizer actually minimizes energy for different
     md_flavors."""
-    print(f"\n--- Starting test_frechet_cell_fire_optimization for {md_flavor=} ---")
 
     # Add random displacement to positions and cell
     # Create a fresh copy for each test run to avoid interference
@@ -543,48 +527,33 @@ def test_frechet_cell_fire_optimization(
         atomic_numbers=ar_supercell_sim_state.atomic_numbers.clone(),
         batch=ar_supercell_sim_state.batch.clone(),
     )
-    print(f"[{md_flavor}] Initial SimState created for Frechet test.")
 
     initial_state_positions = current_sim_state.positions.clone()
     initial_state_cell = current_sim_state.cell.clone()
 
     # Initialize FIRE optimizer
-    print(f"Initializing Frechet {md_flavor} optimizer...")
     init_fn, update_fn = frechet_cell_fire(
         model=lj_model,
         dt_max=0.3,
         dt_start=0.1,
         md_flavor=md_flavor,
     )
-    print(f"[{md_flavor}] Frechet optimizer functions obtained.")
 
     state = init_fn(current_sim_state)
-    energy = float(getattr(state, "energy", "nan"))
-    print(f"[{md_flavor}] Initial state created by Frechet init_fn. {energy=:.4f}")
 
     # Run optimization for a few steps
     energies = [1000.0, state.energy.item()]  # Ensure float for comparison
     max_steps = 1000
     steps_taken = 0
-    print(f"[{md_flavor}] Entering Frechet optimization loop (max_steps: {max_steps})...")
 
     while abs(energies[-2] - energies[-1]) > 1e-6 and steps_taken < max_steps:
         state = update_fn(state)
         energies.append(state.energy.item())
         steps_taken += 1
 
-    print(f"[{md_flavor}] Frechet loop finished after {steps_taken} steps.")
-
-    if steps_taken == max_steps and abs(energies[-2] - energies[-1]) > 1e-6:
-        print(
-            f"WARNING: Frechet Cell FIRE {md_flavor=} optimization did not converge "
-            f"in {max_steps} steps. Final energy: {energies[-1]:.4f}"
-        )
-    else:
-        print(
-            f"Frechet Cell FIRE {md_flavor=} optimization converged in {steps_taken} "
-            f"steps. Final energy: {energies[-1]:.4f}"
-        )
+    assert steps_taken < max_steps, (
+        f"Frechet FIRE {md_flavor=} optimization did not converge in {max_steps=}"
+    )
 
     energies = energies[1:]
 
@@ -600,8 +569,7 @@ def test_frechet_cell_fire_optimization(
     pressure = torch.trace(state.stress.squeeze(0)) / 3.0
 
     # Adjust tolerances if needed, Frechet might behave slightly differently
-    pressure_tol = 0.01
-    force_tol = 0.2
+    pressure_tol, force_tol = 0.01, 0.2
 
     assert torch.abs(pressure) < pressure_tol, (
         f"{md_flavor=} pressure should be below {pressure_tol=} after Frechet "

tests/test_runners.py
Lines changed: 5 additions & 5 deletions

@@ -768,12 +768,12 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None:
 
     cu_atoms = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2))
     many_cu_atoms = [cu_atoms] * 5
-    trajectory_files = [tmp_path / f"Cu_traj_{i}" for i in range(len(many_cu_atoms))]
+    trajectory_files = [tmp_path / f"Cu_traj_{i}.h5md" for i in range(len(many_cu_atoms))]
 
     # run them all simultaneously with batching
     final_state = ts.integrate(
         system=many_cu_atoms,
-        model=lj_model,
+        model=lj_model,  # using LJ instead of MACE for testing
         n_steps=50,
         timestep=0.002,
         temperature=1000,
@@ -788,17 +788,17 @@ def test_readme_example(lj_model: LennardJonesModel, tmp_path: Path) -> None:
     with ts.TorchSimTrajectory(filename) as traj:
         final_energies.append(traj.get_array("potential_energy")[-1])
 
-    print(final_energies)
+    assert len(final_energies) == len(trajectory_files)
 
     # relax all of the high temperature states
     relaxed_state = ts.optimize(
         system=final_state,
         model=lj_model,
         optimizer=ts.frechet_cell_fire,
-        # autobatcher=True,
+        # autobatcher=True,  # disabled for CPU-based LJ model in test
     )
 
-    print(relaxed_state.energy)
+    assert relaxed_state.energy.shape == (final_state.n_batches,)
 
 
 @pytest.fixture
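
Both replaced prints become structural checks: one final energy per trajectory file, and one relaxed energy per batched system. A minimal sketch of such shape-based assertions (n_batches = 5 mirrors the five copies of cu_atoms above; the random tensor is a stand-in for relaxed_state.energy):

    import torch

    n_batches = 5
    energy = torch.randn(n_batches)  # stand-in for relaxed_state.energy

    assert energy.shape == (n_batches,)  # one scalar energy per batched system
    assert torch.isfinite(energy).all()  # and none of them diverged

Shape checks like these are cheap and catch batching regressions that a printed tensor would only reveal to a human reader.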

torch_sim/math.py
Lines changed: 1 addition & 1 deletion

@@ -984,7 +984,7 @@ def matrix_log_33(
             "Falling back to scipy"
         )
         if fallback_warning:
-            print(msg)
+            print(msg)  # noqa: T201
         # Fall back to scipy implementation
         return matrix_log_scipy(matrix).to(sim_dtype)

torch_sim/models/fairchem.py
Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@
 pretrained model checkpoints.
 """
 
+# ruff: noqa: T201
+
 from __future__ import annotations
 
 import copy

torch_sim/models/mace.py
Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ def __init__(
         self.model = self.model.to(dtype=self.dtype)
 
         if enable_cueq:
-            print("Converting models to CuEq for acceleration")
+            print("Converting models to CuEq for acceleration")  # noqa: T201
             self.model = run_e3nn_to_cueq(self.model)
 
         # Set model properties

torch_sim/optimizers.py
Lines changed: 1 addition & 1 deletion

@@ -1599,7 +1599,7 @@ def _ase_fire_step(  # noqa: C901, PLR0915
     volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
     if torch.any(volumes <= 0):
         bad_indices = torch.where(volumes <= 0)[0].tolist()
-        print(
+        print(  # noqa: T201
             f"WARNING: Non-positive volume(s) detected during _ase_fire_step: "
             f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})"
         )
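
The package-code changes above show the two suppression granularities Ruff offers now that T201 is enforced: the file-level directive in torch_sim/models/fairchem.py exempts a whole module that legitimately writes to stdout, while the inline # noqa: T201 comments in math.py, mace.py, and optimizers.py exempt single lines. A sketch of both forms (module and function names are illustrative, not from the repo):

    """Sketch of a module that intentionally writes to stdout."""

    # ruff: noqa: T201
    # ^ file-level directive: every print below is exempt from T201

    def announce(step: int) -> None:
        print(f"step {step}")

    def warn(msg: str) -> None:  # the inline form, preferable for a single line
        print(msg)  # noqa: T201

With the file-level directive present the inline comment is redundant; in practice a module picks one form or the other.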
