@@ -199,10 +199,8 @@ def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tens
199
199
"""
200
200
201
201
202
- def validate_model_outputs (
203
- model : ModelInterface ,
204
- device : torch .device ,
205
- dtype : torch .dtype ,
202
+ def validate_model_outputs ( # noqa: C901, PLR0915
203
+ model : ModelInterface , device : torch .device , dtype : torch .dtype
206
204
) -> None :
207
205
"""Validate the outputs of a model implementation against the interface requirements.
208
206
@@ -233,10 +231,9 @@ def validate_model_outputs(
233
231
"""
234
232
from ase .build import bulk
235
233
236
- assert model .dtype is not None
237
- assert model .device is not None
238
- assert model .compute_stress is not None
239
- assert model .compute_forces is not None
234
+ for attr in ("dtype" , "device" , "compute_stress" , "compute_forces" ):
235
+ if not hasattr (model , attr ):
236
+ raise ValueError (f"model.{ attr } is not set" )
240
237
241
238
try :
242
239
if not model .compute_stress :
@@ -265,52 +262,56 @@ def validate_model_outputs(
265
262
model_output = model .forward (sim_state )
266
263
267
264
# assert model did not mutate the input
268
- assert torch .allclose (og_positions , sim_state .positions )
269
- assert torch .allclose (og_cell , sim_state .cell )
270
- assert torch .allclose (og_batch , sim_state .batch )
271
- assert torch .allclose (og_atomic_numbers , sim_state .atomic_numbers )
265
+ if not torch .allclose (og_positions , sim_state .positions ):
266
+ raise ValueError (f"{ og_positions = } != { sim_state .positions = } " )
267
+ if not torch .allclose (og_cell , sim_state .cell ):
268
+ raise ValueError (f"{ og_cell = } != { sim_state .cell = } " )
269
+ if not torch .allclose (og_batch , sim_state .batch ):
270
+ raise ValueError (f"{ og_batch = } != { sim_state .batch = } " )
271
+ if not torch .allclose (og_atomic_numbers , sim_state .atomic_numbers ):
272
+ raise ValueError (f"{ og_atomic_numbers = } != { sim_state .atomic_numbers = } " )
272
273
273
274
# assert model output has the correct keys
274
- assert "energy" in model_output
275
- assert "forces" in model_output if force_computed else True
276
- assert "stress" in model_output if stress_computed else True
275
+ if "energy" not in model_output :
276
+ raise ValueError ("energy not in model output" )
277
+ if force_computed and "forces" not in model_output :
278
+ raise ValueError ("forces not in model output" )
279
+ if stress_computed and "stress" not in model_output :
280
+ raise ValueError ("stress not in model output" )
277
281
278
282
# assert model output shapes are correct
279
- assert model_output ["energy" ].shape == (2 ,)
280
- assert model_output ["forces" ].shape == (20 , 3 ) if force_computed else True
281
- assert model_output ["stress" ].shape == (2 , 3 , 3 ) if stress_computed else True
283
+ if model_output ["energy" ].shape != (2 ,):
284
+ raise ValueError (f"{ model_output ['energy' ].shape = } != (2,)" )
285
+ if force_computed and model_output ["forces" ].shape != (20 , 3 ):
286
+ raise ValueError (f"{ model_output ['forces' ].shape = } != (20, 3)" )
287
+ if stress_computed and model_output ["stress" ].shape != (2 , 3 , 3 ):
288
+ raise ValueError (f"{ model_output ['stress' ].shape = } != (2, 3, 3)" )
282
289
283
290
si_state = ts .io .atoms_to_state ([si_atoms ], device , dtype )
284
291
fe_state = ts .io .atoms_to_state ([fe_atoms ], device , dtype )
285
292
286
293
si_model_output = model .forward (si_state )
287
- assert torch .allclose (
294
+ if not torch .allclose (
288
295
si_model_output ["energy" ], model_output ["energy" ][0 ], atol = 10e-3
289
- )
290
- assert torch .allclose (
291
- si_model_output ["forces" ],
292
- model_output ["forces" ][: si_state .n_atoms ],
296
+ ):
297
+ raise ValueError (f"{ si_model_output ['energy' ]= } != { model_output ['energy' ][0 ]= } " )
298
+ if not torch .allclose (
299
+ forces := si_model_output ["forces" ],
300
+ expected_forces := model_output ["forces" ][: si_state .n_atoms ],
293
301
atol = 10e-3 ,
294
- )
295
- # assert torch.allclose(
296
- # si_model_output["stress"],
297
- # model_output["stress"][0],
298
- # atol=10e-3,
299
- # )
302
+ ):
303
+ raise ValueError (f"{ forces = } != { expected_forces = } " )
300
304
301
305
fe_model_output = model .forward (fe_state )
302
306
si_model_output = model .forward (si_state )
303
307
304
- assert torch .allclose (
308
+ if not torch .allclose (
305
309
fe_model_output ["energy" ], model_output ["energy" ][1 ], atol = 10e-2
306
- )
307
- assert torch .allclose (
308
- fe_model_output ["forces" ],
309
- model_output ["forces" ][si_state .n_atoms :],
310
+ ):
311
+ raise ValueError (f"{ fe_model_output ['energy' ]= } != { model_output ['energy' ][1 ]= } " )
312
+ if not torch .allclose (
313
+ forces := fe_model_output ["forces" ],
314
+ expected_forces := model_output ["forces" ][si_state .n_atoms :],
310
315
atol = 10e-2 ,
311
- )
312
- # assert torch.allclose(
313
- # arr_model_output["stress"],
314
- # model_output["stress"][1],
315
- # atol=10e-3,
316
- # )
316
+ ):
317
+ raise ValueError (f"{ forces = } != { expected_forces = } " )
0 commit comments