Skip to content

Commit 2549457

Browse files
committed
replace assert statements with ValueError in multiple files
1 parent c69b54e commit 2549457

File tree

9 files changed

+93
-75
lines changed

9 files changed

+93
-75
lines changed

examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,8 @@ def run_optimization_ase( # noqa: C901, PLR0915
484484
final_state_opt: SimState | GDState | None = None
485485

486486
if optimizer_type_val == "torch_sim":
487-
assert ts_md_flavor_val is not None, "ts_md_flavor must be provided for torch_sim"
487+
if ts_md_flavor_val is None:
488+
raise ValueError(f"{ts_md_flavor_val=} must be provided for torch_sim")
488489
steps, final_state_opt = run_optimization_ts(
489490
initial_state=state.clone(),
490491
ts_md_flavor=ts_md_flavor_val,

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,11 @@ ignore = [
102102
"ISC001", # avoid conflicts with the formatter
103103
"N803", # Variable name should be lowercase
104104
"N806", # Uppercase letters in variable names
105-
"PD010", # .pivot_table is preferred to .pivot or .unstack; provides same functionality
106-
"PD015", # pandas-use-of-pd-merge
107105
"PLR0912", # too many branches
108106
"PLR0913", # too many function arguments
109107
"PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
110108
"PLW2901", # Outer for loop variable overwritten by inner assignment target
111109
"PTH", # flake8-use-pathlib
112-
"S101", # Use of assertion statements
113110
"S301", # pickle and modules that wrap it can be unsafe, possible security issue
114111
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
115112
"SIM105", # Use contextlib.suppress instead of try-except-pass

torch_sim/autobatching.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -955,10 +955,8 @@ def _get_first_batch(self) -> SimState:
955955
self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
956956
return concatenate_states([first_state, *states])
957957

958-
def next_batch(
959-
self,
960-
updated_state: SimState | None,
961-
convergence_tensor: torch.Tensor | None,
958+
def next_batch( # noqa: C901
959+
self, updated_state: SimState | None, convergence_tensor: torch.Tensor | None
962960
) -> (
963961
tuple[SimState | None, list[SimState]]
964962
| tuple[SimState | None, list[SimState], list[int]]
@@ -1022,10 +1020,14 @@ def next_batch(
10221020

10231021
# assert statements helpful for debugging, should be moved to validate fn
10241022
# the first two are most important
1025-
assert len(convergence_tensor) == updated_state.n_batches
1026-
assert len(self.current_idx) == len(self.current_scalers)
1027-
assert len(convergence_tensor.shape) == 1
1028-
assert updated_state.n_batches > 0
1023+
if len(convergence_tensor) != updated_state.n_batches:
1024+
raise ValueError(f"{len(convergence_tensor)=} != {updated_state.n_batches=}")
1025+
if len(self.current_idx) != len(self.current_scalers):
1026+
raise ValueError(f"{len(self.current_idx)=} != {len(self.current_scalers)=}")
1027+
if len(convergence_tensor.shape) != 1:
1028+
raise ValueError(f"{len(convergence_tensor.shape)=} != 1")
1029+
if updated_state.n_batches <= 0:
1030+
raise ValueError(f"{updated_state.n_batches=} <= 0")
10291031

10301032
# Increment attempt counters and check for max attempts in a single loop
10311033
for cur_idx, abs_idx in enumerate(self.current_idx):

torch_sim/models/fairchem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def __init__( # noqa: C901, PLR0915
175175
)
176176

177177
# Either the config path or the checkpoint path needs to be provided
178-
assert config_yml or model is not None
178+
if not config_yml and model is None:
179+
raise ValueError("Either config_yml or model must be provided")
179180

180181
checkpoint = None
181182
if config_yml is not None:

torch_sim/models/interface.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,8 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens
199199
"""
200200

201201

202-
def validate_model_outputs(
203-
model: ModelInterface,
204-
device: torch.device,
205-
dtype: torch.dtype,
202+
def validate_model_outputs( # noqa: C901, PLR0915
203+
model: ModelInterface, device: torch.device, dtype: torch.dtype
206204
) -> None:
207205
"""Validate the outputs of a model implementation against the interface requirements.
208206
@@ -233,10 +231,9 @@ def validate_model_outputs(
233231
"""
234232
from ase.build import bulk
235233

236-
assert model.dtype is not None
237-
assert model.device is not None
238-
assert model.compute_stress is not None
239-
assert model.compute_forces is not None
234+
for attr in ("dtype", "device", "compute_stress", "compute_forces"):
235+
if not hasattr(model, attr):
236+
raise ValueError(f"model.{attr} is not set")
240237

241238
try:
242239
if not model.compute_stress:
@@ -265,52 +262,56 @@ def validate_model_outputs(
265262
model_output = model.forward(sim_state)
266263

267264
# assert model did not mutate the input
268-
assert torch.allclose(og_positions, sim_state.positions)
269-
assert torch.allclose(og_cell, sim_state.cell)
270-
assert torch.allclose(og_batch, sim_state.batch)
271-
assert torch.allclose(og_atomic_numbers, sim_state.atomic_numbers)
265+
if not torch.allclose(og_positions, sim_state.positions):
266+
raise ValueError(f"{og_positions=} != {sim_state.positions=}")
267+
if not torch.allclose(og_cell, sim_state.cell):
268+
raise ValueError(f"{og_cell=} != {sim_state.cell=}")
269+
if not torch.allclose(og_batch, sim_state.batch):
270+
raise ValueError(f"{og_batch=} != {sim_state.batch=}")
271+
if not torch.allclose(og_atomic_numbers, sim_state.atomic_numbers):
272+
raise ValueError(f"{og_atomic_numbers=} != {sim_state.atomic_numbers=}")
272273

273274
# assert model output has the correct keys
274-
assert "energy" in model_output
275-
assert "forces" in model_output if force_computed else True
276-
assert "stress" in model_output if stress_computed else True
275+
if "energy" not in model_output:
276+
raise ValueError("energy not in model output")
277+
if force_computed and "forces" not in model_output:
278+
raise ValueError("forces not in model output")
279+
if stress_computed and "stress" not in model_output:
280+
raise ValueError("stress not in model output")
277281

278282
# assert model output shapes are correct
279-
assert model_output["energy"].shape == (2,)
280-
assert model_output["forces"].shape == (20, 3) if force_computed else True
281-
assert model_output["stress"].shape == (2, 3, 3) if stress_computed else True
283+
if model_output["energy"].shape != (2,):
284+
raise ValueError(f"{model_output['energy'].shape=} != (2,)")
285+
if force_computed and model_output["forces"].shape != (20, 3):
286+
raise ValueError(f"{model_output['forces'].shape=} != (20, 3)")
287+
if stress_computed and model_output["stress"].shape != (2, 3, 3):
288+
raise ValueError(f"{model_output['stress'].shape=} != (2, 3, 3)")
282289

283290
si_state = ts.io.atoms_to_state([si_atoms], device, dtype)
284291
fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype)
285292

286293
si_model_output = model.forward(si_state)
287-
assert torch.allclose(
294+
if not torch.allclose(
288295
si_model_output["energy"], model_output["energy"][0], atol=10e-3
289-
)
290-
assert torch.allclose(
291-
si_model_output["forces"],
292-
model_output["forces"][: si_state.n_atoms],
296+
):
297+
raise ValueError(f"{si_model_output['energy']=} != {model_output['energy'][0]=}")
298+
if not torch.allclose(
299+
forces := si_model_output["forces"],
300+
expected_forces := model_output["forces"][: si_state.n_atoms],
293301
atol=10e-3,
294-
)
295-
# assert torch.allclose(
296-
# si_model_output["stress"],
297-
# model_output["stress"][0],
298-
# atol=10e-3,
299-
# )
302+
):
303+
raise ValueError(f"{forces=} != {expected_forces=}")
300304

301305
fe_model_output = model.forward(fe_state)
302306
si_model_output = model.forward(si_state)
303307

304-
assert torch.allclose(
308+
if not torch.allclose(
305309
fe_model_output["energy"], model_output["energy"][1], atol=10e-2
306-
)
307-
assert torch.allclose(
308-
fe_model_output["forces"],
309-
model_output["forces"][si_state.n_atoms :],
310+
):
311+
raise ValueError(f"{fe_model_output['energy']=} != {model_output['energy'][1]=}")
312+
if not torch.allclose(
313+
forces := fe_model_output["forces"],
314+
expected_forces := model_output["forces"][si_state.n_atoms :],
310315
atol=10e-2,
311-
)
312-
# assert torch.allclose(
313-
# arr_model_output["stress"],
314-
# model_output["stress"][1],
315-
# atol=10e-3,
316-
# )
316+
):
317+
raise ValueError(f"{forces=} != {expected_forces=}")

torch_sim/models/soft_sphere.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,9 +642,10 @@ def __init__(
642642
)
643643

644644
# Ensure parameter matrices are symmetric (required for energy conservation)
645-
assert torch.allclose(self.sigma_matrix, self.sigma_matrix.T)
646-
assert torch.allclose(self.epsilon_matrix, self.epsilon_matrix.T)
647-
assert torch.allclose(self.alpha_matrix, self.alpha_matrix.T)
645+
for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"):
646+
matrix = getattr(self, matrix_name)
647+
if not torch.allclose(matrix, matrix.T):
648+
raise ValueError(f"{matrix_name} is not symmetric")
648649

649650
# Set interaction cutoff distance
650651
self.cutoff = torch.tensor(

torch_sim/neighbors.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915
109109
1 / l3 if l3 > 0 else pytorch_scalar_1,
110110
]
111111
)
112-
assert face_dist_c.shape == (3,)
112+
if face_dist_c.shape != (3,):
113+
raise ValueError(f"{face_dist_c.shape=} != (3,)")
113114

114115
# we don't handle other fancier cutoffs
115116
max_cutoff: torch.Tensor = cutoff
@@ -214,8 +215,10 @@ def primitive_neighbor_list( # noqa: C901, PLR0915
214215
bin_index_i = bin_index_i[mask]
215216

216217
# Make sure that all atoms have been sorted into bins.
217-
assert len(atom_i) == 0
218-
assert len(bin_index_i) == 0
218+
if len(atom_i) != 0:
219+
raise ValueError(f"{len(atom_i)=} != 0")
220+
if len(bin_index_i) != 0:
221+
raise ValueError(f"{len(bin_index_i)=} != 0")
219222

220223
# Now we construct neighbor pairs by pairing up all atoms within a bin or
221224
# between bin and neighboring bin. atom_pairs_pn is a helper buffer that

torch_sim/optimizers.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,8 @@ def _vv_fire_step( # noqa: C901, PLR0915
12731273
if is_cell_optimization:
12741274
cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1)
12751275
if is_frechet:
1276-
assert isinstance(state, FrechetCellFIREState)
1276+
if not isinstance(state, expected_cls := FrechetCellFIREState):
1277+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
12771278
cur_deform_grad = state.deform_grad()
12781279
deform_grad_log = torch.zeros_like(cur_deform_grad)
12791280
for b in range(n_batches):
@@ -1291,7 +1292,8 @@ def _vv_fire_step( # noqa: C901, PLR0915
12911292
state.row_vector_cell = new_row_vector_cell
12921293
state.cell_positions = cell_positions_log_scaled_new
12931294
else:
1294-
assert isinstance(state, UnitCellFireState)
1295+
if not isinstance(state, expected_cls := UnitCellFireState):
1296+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
12951297
cur_deform_grad = state.deform_grad()
12961298
cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1)
12971299
current_cell_positions_scaled = (
@@ -1329,7 +1331,8 @@ def _vv_fire_step( # noqa: C901, PLR0915
13291331
).unsqueeze(0).expand(n_batches, -1, -1)
13301332

13311333
if is_frechet:
1332-
assert isinstance(state, FrechetCellFIREState)
1334+
if not isinstance(state, expected_cls := FrechetCellFIREState):
1335+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
13331336
ucf_cell_grad = torch.bmm(
13341337
virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2))
13351338
)
@@ -1353,7 +1356,8 @@ def _vv_fire_step( # noqa: C901, PLR0915
13531356
new_cell_forces[b] = forces_flat.reshape(3, 3)
13541357
state.cell_forces = new_cell_forces / cell_factor_reshaped
13551358
else:
1356-
assert isinstance(state, UnitCellFireState)
1359+
if not isinstance(state, expected_cls := UnitCellFireState):
1360+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
13571361
state.cell_forces = virial / cell_factor_reshaped
13581362

13591363
state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
@@ -1564,15 +1568,17 @@ def _ase_fire_step( # noqa: C901, PLR0915
15641568
)
15651569

15661570
if is_frechet:
1567-
assert isinstance(state, FrechetCellFIREState)
1571+
if not isinstance(state, expected_cls := FrechetCellFIREState):
1572+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
15681573
new_logm_F_scaled = state.cell_positions + dr_cell
15691574
state.cell_positions = new_logm_F_scaled
15701575
logm_F_new = new_logm_F_scaled / (state.cell_factor + eps)
15711576
F_new = torch.matrix_exp(logm_F_new)
15721577
new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT)
15731578
state.row_vector_cell = new_row_vector_cell
15741579
else:
1575-
assert isinstance(state, UnitCellFireState)
1580+
if not isinstance(state, expected_cls := UnitCellFireState):
1581+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
15761582
F_current = state.deform_grad()
15771583
cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1)
15781584
current_F_scaled = F_current * cell_factor_exp_mult
@@ -1619,13 +1625,16 @@ def _ase_fire_step( # noqa: C901, PLR0915
16191625
).unsqueeze(0).expand(n_batches, -1, -1)
16201626

16211627
if is_frechet:
1622-
assert isinstance(state, FrechetCellFIREState)
1623-
assert F_new is not None, (
1624-
"F_new should be defined for Frechet cell force calculation"
1625-
)
1626-
assert logm_F_new is not None, (
1627-
"logm_F_new should be defined for Frechet cell force calculation"
1628-
)
1628+
if not isinstance(state, expected_cls := FrechetCellFIREState):
1629+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
1630+
if F_new is None:
1631+
raise ValueError(
1632+
"F_new should be defined for Frechet cell force calculation"
1633+
)
1634+
if logm_F_new is None:
1635+
raise ValueError(
1636+
"logm_F_new should be defined for Frechet cell force calculation"
1637+
)
16291638
ucf_cell_grad = torch.bmm(
16301639
virial, torch.linalg.inv(torch.transpose(F_new, 1, 2))
16311640
)
@@ -1649,7 +1658,8 @@ def _ase_fire_step( # noqa: C901, PLR0915
16491658
new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3)
16501659
state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps)
16511660
else:
1652-
assert isinstance(state, UnitCellFireState)
1661+
if not isinstance(state, expected_cls := UnitCellFireState):
1662+
raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
16531663
state.cell_forces = virial / state.cell_factor
16541664

16551665
return state

torch_sim/transforms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,10 @@ def compute_distances_with_cell_shifts(
521521
torch.Tensor: A tensor of shape (n_pairs,) containing the
522522
computed distances for each pair.
523523
"""
524-
assert mapping.dim() == 2
525-
assert mapping.shape[0] == 2
524+
if mapping.dim() != 2:
525+
raise ValueError(f"Mapping must be a 2D tensor, got {mapping.shape}")
526+
if mapping.shape[0] != 2:
527+
raise ValueError(f"Mapping must have 2 rows, got {mapping.shape[0]}")
526528

527529
if cell_shifts is None:
528530
dr = pos[mapping[1]] - pos[mapping[0]]

0 commit comments

Comments
 (0)