-
Notifications
You must be signed in to change notification settings - Fork 39
Remove unbatched code #206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…egrators into folder
… as the MaceModel unit tests didn't fail removing them.
…m` (#203) * fea: use batched vdot --------- Co-authored-by: Janosh Riebesell <[email protected]> * clean: remove ai slop * clean: further attempts to clean but still not matching PR * fix: dr is vdt rather than fdt * typing: fix typing issue * wip: still not sure where the difference is now * update forces per comment * Fix ASE pos only implementation * Initialize velocities to None in the pos-only case, see previous changes to the optimizers. (ensures that the correct `dt` is used) * Change the order of the increase of `n_pos`. Again, this ensures the usage of the correct `dt` compared to ASE * Fix torch-sim ASE-FIRE (Frechet Cell) * Remove rescaling of positions when updating cell, it's not relevant * Correctly rescale the positions with respect to the deformation gradient * Consider the `cell_forces` in the convergence when doing cell optimizations * linting * test: still differ significantly after step 1 for distorted structures * Fix test comparing ASE and torch-sim optimization * Include the `cell_forces` in the convergence check * Fix the number of iterations that are performed. `steps_between_swaps` is set to 1, so the number of iterations is equal to the number of swaps. In the previous version, less iterations would have been performed when reaching the maximum number of swaps. For example, when trying to run 32 steps with `steps_between_swaps=5`, the optimization would have stopped after 30 iterations, i.e., 6 swaps. * Fix `autobatching.py`. The if statement would have been triggered for `max_attempts=0`, which was the case when running one iteration and `steps_between_swaps=5` * Fix `optimizers` when using `UnitCellFilter` * Fix the `None` initialization * Fix the cell update when using `UnitCellFilter` * fix test_optimize_fire * allow FireState.velocities = None since it's being set to None in multiple places * safer `batched_vdot`: check dimensionality of input tensors `y` and `batch_indices` - fix stale docstring mentioning is_sum_sq kwarg * generate_force_convergence_fn raise informative error on needed but missing cell_forces * pascal case VALID_FIRE_CELL_STATES->AnyFireCellState and fix non-f-string error messages * fix FireState TypeError: non-default argument 'dt' follows default argument * allow None but don't set default for state.velocities * fix bad merge conflict resolution * tweaks --------- Co-authored-by: Rhys Goodall <[email protected]> Co-authored-by: Janosh Riebesell <[email protected]>
@CompRhys @abhijeetgangan looks like this (or an earlier PR) broke the temperature profile in |
all those scripts previously used unbatched code and so they will never have worked with the batched code. I can only assume the unbatched versions did work. |
a2ced4f
to
e2330c0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 14
🔭 Outside diff range comments (5)
torch_sim/models/particle_life.py (2)
196-198
:⚠️ Potential issueIncorrect call signature – unexpected keyword
epsilon
asymmetric_particle_pair_force_jit()
expects parameters(dr, A, beta=..., sigma=...)
.
The current call passessigma
and an unsupportedepsilon
, so it will raiseTypeError
.-pair_forces = asymmetric_particle_pair_force_jit( - distances, sigma=self.sigma, epsilon=self.epsilon -) +pair_forces = asymmetric_particle_pair_force_jit( + distances, + A=self.epsilon, # interaction strength + sigma=self.sigma, +)Also consider threading
beta
through if you plan to expose it as a model hyper-parameter.🧰 Tools
🪛 Pylint (3.3.7)
[error] 196-198: Unexpected keyword argument 'epsilon' in function call
(E1123)
[error] 196-198: No value for argument 'A' in function call
(E1120)
205-210
: 🛠️ Refactor suggestion
forces
are always computed regardless ofcompute_forces
flag
self._compute_forces
(set in__init__
) is ignored here – the function always builds and returns the forces tensor.
Either respect the flag or remove it from the constructor to avoid surprising users.-if self._compute_forces: - results["forces"] = forces +if self._compute_forces: + results["forces"] = forces… and wrap the expensive calculations above with the same guard.
torch_sim/workflows/a2c.py (2)
290-301
:⚠️ Potential issueWrong shape for
atomic_numbers
tensor
torch.ones_like(positions_cart, …)
makesatomic_numbers
shape[N_atoms, 3]
, yet downstream code (e.g.SimState
) expects a 1-D[N_atoms]
vector.
This silently broadcasts and later corrupts per-atom metadata.-atomic_numbers = torch.ones_like(positions_cart, device=device, dtype=torch.int) +atomic_numbers = torch.ones( + positions_cart.shape[0], device=device, dtype=torch.int +)
418-430
:⚠️ Potential issueMulti-species: wrong
atomic_numbers
dimensionalityIdentical shape issue as the single-species routine – please fix to a 1-D tensor.
torch_sim/optimizers.py (1)
1207-1221
: 🛠️ Refactor suggestionMonolithic
_vv_fire_step
violates single-responsibilityThe function carries >100 lines, ~20 arguments, many branches, and handles three algorithm flavours.
Recommend extracting helpers (e.g. cell-update, parameter-update, mixing) to improve testability and readability.🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 1207-1207: Too many arguments (11/5)
(R0913)
[refactor] 1207-1207: Too many local variables (57/15)
(R0914)
[refactor] 1207-1207: Too many branches (18/12)
(R0912)
[refactor] 1207-1207: Too many statements (98/50)
(R0915)
♻️ Duplicate comments (1)
examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py (1)
125-133
:⚠️ Potential issueModel and state initialization updated for batched API.
Same cell tensor issue affects this script as well. The model initialization is correct, but the pipeline failure suggests a systematic problem with cell tensor formatting.
This script has the same cell tensor shape issue as
3.12_MACE_NPT_Langevin.py
. The neighbor list function expects a 3x3 tensor but receives a [32, 3] tensor.
🧹 Nitpick comments (16)
examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py (1)
39-40
: ApplySMOKE_TEST
for environment-based step control.
Refactoring the inline CI check intoSMOKE_TEST
improves clarity; consider centralizing this pattern across all example scripts.examples/scripts/1_Introduction/1.2_MACE.py (1)
82-88
: Minor style improvement: Use dict literal.Consider using dictionary literal syntax for better readability.
-results = batched_model( - dict( - positions=positions, - cell=cell, - atomic_numbers=atomic_numbers, - batch=batch, - pbc=True, - ) -) +results = batched_model({ + "positions": positions, + "cell": cell, + "atomic_numbers": atomic_numbers, + "batch": batch, + "pbc": True, +})🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 82-88: Consider using '{"positions": positions, "cell": cell, "atomic_numbers": atomic_numbers, ... }' instead of a call to 'dict'.
(R1735)
torch_sim/autobatching.py (1)
1033-1035
: Good defensive programming improvement.The explicit
is not None
check is safer than relying on truthiness, as it correctly handles the case wheremax_attempts
could be0
(a valid value that would be falsy).Note: Static analysis indicates this line is not covered by tests. Consider adding a test case for the
max_attempts
edge cases.🧰 Tools
🪛 GitHub Check: codecov/patch
[warning] 1033-1033: torch_sim/autobatching.py#L1033
Added line #L1033 was not covered by testsexamples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py (1)
64-77
: Variable naming inconsistency with NPT integrator.The
npt_nose_hoover
function returns initialization and update functions that are assigned tonvt_init
andnvt_update
. This naming is misleading since this is an NPT simulation.Consider renaming for clarity:
-nvt_init, nvt_update = npt_nose_hoover(model=model, kT=kT, dt=dt) +npt_init, npt_update = npt_nose_hoover(model=model, kT=kT, dt=dt) -state = nvt_init(state=state, seed=1) +state = npt_init(state=state, seed=1) - state = nvt_update(state, kT=kT) + state = npt_update(state, kT=kT)examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py (1)
200-200
: Consider if single optimization step is sufficient.Reducing
steps_this_round
from 10 to 1 significantly decreases the optimization granularity. While this may be acceptable for testing purposes, it could impact the effectiveness of the optimization comparison.Consider whether a single step provides meaningful optimization progress for the comparison, or if a moderate value like 3-5 steps would be more appropriate.
torch_sim/models/metatomic.py (1)
231-240
: CPU round-trip for neighbour building is a hotspot
vesin.metatomic.compute_requested_neighbors()
is forced onto CPU, then data is copied back toself.device
.
For GPU runs this introduces two full copies of potentially largeSystem
objects per step.If
vesin
supports CUDA in the near future, gate this logic or at least emit a warning once:if self._device.type == "cuda": warnings.warn( "Neighbour list temporarily built on CPU; this incurs a host<->device copy.", stacklevel=2, )Long-term: contribute GPU support or cache neighbour lists to amortise the copy cost.
torch_sim/models/particle_life.py (1)
251-259
: Aggregation logic drops unhandled global keysOnly
"energy"
and"stress"
are stacked. Ifunbatched_forward
later returns additional global properties (e.g.,"temperature"
), they will be silently ignored.Consider iterating dynamically:
for key in properties: if key in ("forces", "energies", "stresses"): results[key] = torch.cat([out[key] for out in outputs]) else: # global tensor results[key] = torch.stack([out[key] for out in outputs])Keeps the wrapper future-proof with minimal overhead.
torch_sim/integrators/md.py (2)
90-96
: Hard-coding three spatial dimensions limits generalityThe COM-removal logic replicates the
batch
vector three times:batch_3 = batch.view(-1, 1).repeat(1, 3)If 2-D or 4-D simulations are ever used, this silently breaks.
Derive the dimension frompositions
instead:- # create 3 copies of batch - batch_3 = batch.view(-1, 1).repeat(1, 3) + dim = positions.shape[1] + batch_dim = batch.view(-1, 1).repeat(1, dim)(Subsequent uses of
batch_3
need to referencebatch_dim
.)
Keeps the helper truly dimension-agnostic.
112-130
: In-place state mutation may hinder differentiability & functional reuse
momentum_step()
updatesstate.momenta
in-place.
While fine for plain MD, in-place ops:
- Prevent gradient tracking through the old
state
,- Make the function non-pure, complicating JIT or parallel use.
Consider returning a new
MDState
(orstate.clone()
first) to keep the integrator side-effect-free.torch_sim/workflows/a2c.py (3)
229-257
: Docstring & return-type drift – update toFireState
random_packed_structure()
now returns aFireState
, but the docstring (lines 254-257) still advertisesFIREState
.
Please adjust the docstring (and the “Notes” section a few lines below) so that tooling and users don’t rely on a removed class name.
341-372
: Same docstring mismatch for multi-species variant
random_packed_structure_multi()
also returnsFireState
, yet the docstring still referencesFIREState
. Synch the wording here too.
800-806
: Inconsistent variable naming (Initial_energy
)Variable names should stay snake-case. Rename
Initial_energy
➜initial_energy
to avoid confusion and potential lint failures.-Initial_energy = results["energy"] -print(f"Initial energy: {Initial_energy.item():.4f} eV") +initial_energy = results["energy"] +print(f"Initial energy: {initial_energy.item():.4f} eV")torch_sim/integrators/nvt.py (1)
441-461
: Docstring arguments out of sync
nvt_nose_hoover_invariant()
docstring still mentions anenergy_fn
parameter that no longer exists.
Clean this up to avoid stale documentation.torch_sim/optimizers.py (3)
273-289
:cell_factor.view()
may fail on non-contiguous tensors
cell_factor.view(n_batches,1,1)
assumes contiguous memory; usereshape
to be safe after possible broadcasts:-cell_factor = cell_factor.view(n_batches, 1, 1) +cell_factor = cell_factor.reshape(n_batches, 1, 1)
321-323
: Volume tensor already 3-D – avoid unnecessaryview
torch.linalg.det(state.cell)
already yields[n_batches]
.
view(n_batches,1,1)
is an unnecessary reshape that can break if tensor isn’t contiguous.
Considervolumes[:, None, None]
instead for clarity.
1368-1395
: Repeatedsqrt
/division work – micro-optimisation
v_scaling_batch
andf_scaling_batch
are reused in both atomic and cell branches.
Cache their square-roots once rather than recomputing for atoms and cells.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (84)
.github/workflows/docs.yml
(1 hunks).github/workflows/link-check.yml
(0 hunks).github/workflows/test.yml
(2 hunks).pre-commit-config.yaml
(2 hunks)README.md
(2 hunks)docs/conf.py
(1 hunks)docs/tutorials/index.rst
(1 hunks)examples/readme.md
(1 hunks)examples/scripts/1_Introduction/1.1_Lennard_Jones.py
(2 hunks)examples/scripts/1_Introduction/1.2_MACE.py
(2 hunks)examples/scripts/1_Introduction/1.3_Batched_MACE.py
(0 hunks)examples/scripts/1_Introduction/1.3_Fairchem.py
(1 hunks)examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py
(4 hunks)examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py
(4 hunks)examples/scripts/2_Structural_optimization/2.3_MACE_FIRE.py
(0 hunks)examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py
(1 hunks)examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py
(1 hunks)examples/scripts/2_Structural_optimization/2.4_MACE_UnitCellFilter_FIRE.py
(0 hunks)examples/scripts/2_Structural_optimization/2.5_MACE_FrechetCellFilter_FIRE.py
(0 hunks)examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descent.py
(1 hunks)examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py
(1 hunks)examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py
(1 hunks)examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
(2 hunks)examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py
(5 hunks)examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
(2 hunks)examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py
(4 hunks)examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py
(3 hunks)examples/scripts/3_Dynamics/3.2_MACE_NVE.py
(5 hunks)examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
(3 hunks)examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
(5 hunks)examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py
(2 hunks)examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py
(4 hunks)examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py
(6 hunks)examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
(2 hunks)examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py
(3 hunks)examples/scripts/4_High_level_api/4.1_high_level_api.py
(9 hunks)examples/scripts/4_High_level_api/4.2_auto_batching_api.py
(5 hunks)examples/scripts/5_Workflow/5.1_a2c_silicon.py
(0 hunks)examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py
(6 hunks)examples/scripts/5_Workflow/5.2_In_Flight_WBM.py
(3 hunks)examples/scripts/7_Others/7.1_Soft_sphere_autograd.py
(1 hunks)examples/scripts/7_Others/7.2_Stress_autograd.py
(1 hunks)examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py
(4 hunks)examples/tutorials/metatomic_tutorial.py
(3 hunks)pyproject.toml
(2 hunks)tests/conftest.py
(0 hunks)tests/models/test_lennard_jones.py
(3 hunks)tests/models/test_metatensor.py
(0 hunks)tests/models/test_metatomic.py
(1 hunks)tests/models/test_morse.py
(1 hunks)tests/test_integrators.py
(1 hunks)tests/test_optimizers_vs_ase.py
(5 hunks)tests/test_runners.py
(2 hunks)tests/unbatched/conftest.py
(0 hunks)tests/unbatched/test_unbatched_integrators.py
(0 hunks)tests/unbatched/test_unbatched_mace.py
(0 hunks)tests/unbatched/test_unbatched_optimizers.py
(0 hunks)tests/workflows/test_a2c.py
(3 hunks)torch_sim/__init__.py
(0 hunks)torch_sim/autobatching.py
(1 hunks)torch_sim/integrators/__init__.py
(1 hunks)torch_sim/integrators/md.py
(1 hunks)torch_sim/integrators/npt.py
(3 hunks)torch_sim/integrators/nve.py
(1 hunks)torch_sim/integrators/nvt.py
(1 hunks)torch_sim/math.py
(1 hunks)torch_sim/models/lennard_jones.py
(1 hunks)torch_sim/models/mace.py
(1 hunks)torch_sim/models/metatomic.py
(7 hunks)torch_sim/models/morse.py
(1 hunks)torch_sim/models/particle_life.py
(4 hunks)torch_sim/models/soft_sphere.py
(1 hunks)torch_sim/optimizers.py
(23 hunks)torch_sim/quantities.py
(2 hunks)torch_sim/runners.py
(5 hunks)torch_sim/state.py
(2 hunks)torch_sim/unbatched/models/lennard_jones.py
(0 hunks)torch_sim/unbatched/models/mace.py
(0 hunks)torch_sim/unbatched/models/morse.py
(0 hunks)torch_sim/unbatched/models/soft_sphere.py
(0 hunks)torch_sim/unbatched/readme.md
(0 hunks)torch_sim/unbatched/unbatched_integrators.py
(0 hunks)torch_sim/unbatched/unbatched_optimizers.py
(0 hunks)torch_sim/workflows/a2c.py
(9 hunks)
💤 Files with no reviewable changes (20)
- torch_sim/init.py
- torch_sim/unbatched/readme.md
- .github/workflows/link-check.yml
- tests/conftest.py
- examples/scripts/1_Introduction/1.3_Batched_MACE.py
- examples/scripts/2_Structural_optimization/2.4_MACE_UnitCellFilter_FIRE.py
- tests/unbatched/conftest.py
- examples/scripts/2_Structural_optimization/2.3_MACE_FIRE.py
- tests/models/test_metatensor.py
- examples/scripts/2_Structural_optimization/2.5_MACE_FrechetCellFilter_FIRE.py
- torch_sim/unbatched/models/mace.py
- tests/unbatched/test_unbatched_mace.py
- tests/unbatched/test_unbatched_optimizers.py
- torch_sim/unbatched/models/morse.py
- tests/unbatched/test_unbatched_integrators.py
- examples/scripts/5_Workflow/5.1_a2c_silicon.py
- torch_sim/unbatched/models/lennard_jones.py
- torch_sim/unbatched/models/soft_sphere.py
- torch_sim/unbatched/unbatched_optimizers.py
- torch_sim/unbatched/unbatched_integrators.py
🧰 Additional context used
🧬 Code Graph Analysis (18)
examples/scripts/7_Others/7.1_Soft_sphere_autograd.py (1)
torch_sim/models/soft_sphere.py (2)
soft_sphere_pair
(59-93)soft_sphere_pair_force
(96-130)
tests/test_integrators.py (1)
torch_sim/integrators/npt.py (1)
_compute_cell_force
(76-129)
tests/models/test_morse.py (1)
torch_sim/models/morse.py (3)
MorseModel
(115-399)morse_pair
(42-76)morse_pair_force
(79-112)
examples/scripts/7_Others/7.2_Stress_autograd.py (1)
torch_sim/models/lennard_jones.py (2)
lennard_jones_pair
(40-76)lennard_jones_pair_force
(79-119)
examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py (3)
torch_sim/models/soft_sphere.py (1)
SoftSphereModel
(133-433)torch_sim/optimizers.py (4)
fire
(487-618)fire_init
(554-601)fire_init
(776-885)fire_init
(1060-1184)torch_sim/state.py (1)
SimState
(26-312)
examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py (3)
torch_sim/integrators/nvt.py (2)
nvt_langevin
(24-226)langevin_update
(180-224)torch_sim/models/mace.py (3)
MaceModel
(41-50)MaceModel
(80-367)MaceUrls
(370-375)torch_sim/quantities.py (1)
calc_kT
(23-69)
tests/models/test_lennard_jones.py (2)
torch_sim/models/lennard_jones.py (1)
LennardJonesModel
(122-414)tests/conftest.py (1)
lj_model
(34-45)
examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py (4)
torch_sim/integrators/npt.py (4)
npt_langevin
(132-793)momenta
(70-72)npt_init
(608-680)npt_update
(682-791)torch_sim/models/mace.py (2)
MaceModel
(41-50)MaceModel
(80-367)torch_sim/quantities.py (3)
calc_kinetic_energy
(96-130)calc_kT
(23-69)get_pressure
(133-141)torch_sim/io.py (1)
atoms_to_state
(180-245)
torch_sim/quantities.py (4)
torch_sim/state.py (2)
momenta
(319-325)volume
(166-168)torch_sim/integrators/npt.py (2)
momenta
(70-72)velocities
(858-864)torch_sim/integrators/md.py (1)
velocities
(44-48)torch_sim/integrators/nvt.py (1)
velocities
(263-267)
torch_sim/integrators/nve.py (2)
torch_sim/integrators/md.py (4)
MDState
(13-48)calculate_momenta
(51-109)momentum_step
(112-129)position_step
(132-156)torch_sim/state.py (1)
SimState
(26-312)
examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py (5)
torch_sim/integrators/npt.py (3)
npt_nose_hoover
(885-1449)npt_nose_hoover_invariant
(1452-1520)momenta
(70-72)torch_sim/models/mace.py (2)
MaceModel
(41-50)MaceModel
(80-367)torch_sim/quantities.py (3)
calc_kinetic_energy
(96-130)calc_kT
(23-69)get_pressure
(133-141)torch_sim/state.py (4)
device
(137-139)dtype
(142-144)momenta
(319-325)volume
(166-168)torch_sim/io.py (1)
atoms_to_state
(180-245)
examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py (3)
torch_sim/integrators/nve.py (1)
nve
(17-141)torch_sim/models/mace.py (3)
MaceModel
(41-50)MaceModel
(80-367)MaceUrls
(370-375)torch_sim/io.py (1)
atoms_to_state
(180-245)
examples/scripts/1_Introduction/1.1_Lennard_Jones.py (1)
torch_sim/models/lennard_jones.py (1)
LennardJonesModel
(122-414)
torch_sim/models/metatomic.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(40-199)
examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py (2)
torch_sim/runners.py (3)
convergence_fn
(303-317)convergence_fn
(334-339)generate_force_convergence_fn
(287-319)tests/test_autobatching.py (1)
convergence_fn
(448-456)
torch_sim/models/soft_sphere.py (1)
torch_sim/transforms.py (1)
safe_mask
(1117-1146)
torch_sim/workflows/a2c.py (2)
torch_sim/models/soft_sphere.py (2)
SoftSphereModel
(133-433)SoftSphereMultiModel
(436-873)torch_sim/optimizers.py (3)
FireState
(443-484)UnitCellFireState
(622-693)fire
(487-618)
torch_sim/optimizers.py (2)
torch_sim/state.py (8)
n_batches
(161-163)device
(137-139)dtype
(142-144)deform_grad
(350-356)reference_row_vector_cell
(328-330)reference_row_vector_cell
(333-335)row_vector_cell
(185-187)row_vector_cell
(190-196)torch_sim/math.py (1)
batched_vdot
(992-1022)
🪛 GitHub Actions: Tests
examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py
[error] 123-123: RuntimeError in TorchScript interpreter: box
must be 3 x 3 tensor, but the shape is [64, 3]. This caused the process to fail with exit code 1.
examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py
[error] 66-66: RuntimeError in TorchScript interpreter: box
must be 3 x 3 tensor, but the shape is [64, 3]. This error occurs during neighbor_list_fn.compute call in torch_sim/neighbors.py and vesin/torch/_neighbors.py.
examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
[error] 66-66: RuntimeError in TorchScript interpreter: box
must be 3 x 3 tensor, but the shape is [64, 3]. This error occurs during the call to neighbor_list_fn.compute in torch_sim/neighbors.py.
examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
[error] 16-16: ImportError: cannot import name 'npt_nose_hoover' from 'torch_sim.integrators'. This caused the script to fail with exit code 1.
examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
[error] 47-47: RuntimeError in loading state_dict for ScaleShiftMACE: Unexpected keys and size mismatches in model parameters during state_dict loading. Keys include 'products.0.symmetric_contractions.weight', 'products.1.symmetric_contractions.weight'. Multiple parameters have size mismatches between checkpoint and current model.
examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py
[error] 143-143: RuntimeError in TorchScript interpreter: box
must be 3 x 3 tensor, but the shape is [32, 3]. This error occurs during the call to neighbor_list_fn.compute in torch_sim/neighbors.py.
examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py
[error] 71-71: TypeError: nvt_langevin() missing 1 required keyword-only argument: 'dt'. The function call at line 71 is missing the required 'dt' argument.
examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py
[error] 120-120: ValueError: Cannot categorize property 'cell' with shape torch.Size([3, 3]). Expected first dimension to be either 256 (per-atom) or 1 (per-batch), or a scalar (global). Command failed with exit code 1.
🪛 Pylint (3.3.7)
examples/scripts/1_Introduction/1.2_MACE.py
[refactor] 82-88: Consider using '{"positions": positions, "cell": cell, "atomic_numbers": atomic_numbers, ... }' instead of a call to 'dict'.
(R1735)
torch_sim/models/metatomic.py
[error] 51-51: class already defined line 39
(E0102)
[refactor] 51-51: Too many instance attributes (9/7)
(R0902)
torch_sim/models/particle_life.py
[refactor] 86-86: Too many instance attributes (10/7)
(R0902)
[refactor] 128-128: Too many local variables (16/15)
(R0914)
tests/test_optimizers_vs_ase.py
[refactor] 19-19: Too many local variables (17/15)
(R0914)
torch_sim/integrators/nvt.py
[refactor] 270-270: Too many arguments (6/5)
(R0913)
[error] 365-365: Too many positional arguments for function call
(E1121)
torch_sim/integrators/md.py
[refactor] 324-324: Too many local variables (17/15)
(R0914)
torch_sim/integrators/npt.py
[refactor] 797-797: Too many instance attributes (9/7)
(R0902)
[refactor] 885-885: Too many arguments (7/5)
(R0913)
[refactor] 885-885: Too many local variables (18/15)
(R0914)
[refactor] 1112-1112: Too many arguments (7/5)
(R0913)
[refactor] 1112-1112: Too many positional arguments (7/5)
(R0917)
[refactor] 1164-1164: Too many local variables (19/15)
(R0914)
[refactor] 1261-1261: Too many local variables (20/15)
(R0914)
[error] 1375-1375: Too many positional arguments for function call
(E1121)
[refactor] 885-885: Too many statements (94/50)
(R0915)
torch_sim/optimizers.py
[refactor] 1207-1207: Too many arguments (11/5)
(R0913)
[refactor] 1207-1207: Too many local variables (57/15)
(R0914)
[refactor] 1207-1207: Too many branches (18/12)
(R0912)
[refactor] 1207-1207: Too many statements (98/50)
(R0915)
[refactor] 1417-1417: Too many arguments (12/5)
(R0913)
[refactor] 1417-1417: Too many local variables (59/15)
(R0914)
[refactor] 1417-1417: Too many branches (21/12)
(R0912)
[refactor] 1417-1417: Too many statements (110/50)
(R0915)
🪛 GitHub Check: codecov/patch
torch_sim/autobatching.py
[warning] 1033-1033: torch_sim/autobatching.py#L1033
Added line #L1033 was not covered by tests
torch_sim/integrators/md.py
[warning] 183-185: torch_sim/integrators/md.py#L183-L185
Added lines #L183 - L185 were not covered by tests
[warning] 187-187: torch_sim/integrators/md.py#L187
Added line #L187 was not covered by tests
[warning] 189-191: torch_sim/integrators/md.py#L189-L191
Added lines #L189 - L191 were not covered by tests
[warning] 300-300: torch_sim/integrators/md.py#L300
Added line #L300 was not covered by tests
[warning] 313-314: torch_sim/integrators/md.py#L313-L314
Added lines #L313 - L314 were not covered by tests
[warning] 316-317: torch_sim/integrators/md.py#L316-L317
Added lines #L316 - L317 were not covered by tests
[warning] 319-320: torch_sim/integrators/md.py#L319-L320
Added lines #L319 - L320 were not covered by tests
[warning] 322-322: torch_sim/integrators/md.py#L322
Added line #L322 was not covered by tests
[warning] 324-324: torch_sim/integrators/md.py#L324
Added line #L324 was not covered by tests
[warning] 338-338: torch_sim/integrators/md.py#L338
Added line #L338 was not covered by tests
[warning] 347-349: torch_sim/integrators/md.py#L347-L349
Added lines #L347 - L349 were not covered by tests
[warning] 351-351: torch_sim/integrators/md.py#L351
Added line #L351 was not covered by tests
[warning] 354-355: torch_sim/integrators/md.py#L354-L355
Added lines #L354 - L355 were not covered by tests
[warning] 357-360: torch_sim/integrators/md.py#L357-L360
Added lines #L357 - L360 were not covered by tests
[warning] 363-365: torch_sim/integrators/md.py#L363-L365
Added lines #L363 - L365 were not covered by tests
[warning] 368-370: torch_sim/integrators/md.py#L368-L370
Added lines #L368 - L370 were not covered by tests
[warning] 373-373: torch_sim/integrators/md.py#L373
Added line #L373 was not covered by tests
[warning] 376-381: torch_sim/integrators/md.py#L376-L381
Added lines #L376 - L381 were not covered by tests
[warning] 383-383: torch_sim/integrators/md.py#L383
Added line #L383 was not covered by tests
[warning] 385-385: torch_sim/integrators/md.py#L385
Added line #L385 was not covered by tests
[warning] 398-400: torch_sim/integrators/md.py#L398-L400
Added lines #L398 - L400 were not covered by tests
[warning] 402-403: torch_sim/integrators/md.py#L402-L403
Added lines #L402 - L403 were not covered by tests
[warning] 405-407: torch_sim/integrators/md.py#L405-L407
Added lines #L405 - L407 were not covered by tests
[warning] 409-409: torch_sim/integrators/md.py#L409
Added line #L409 was not covered by tests
[warning] 411-411: torch_sim/integrators/md.py#L411
Added line #L411 was not covered by tests
[warning] 421-422: torch_sim/integrators/md.py#L421-L422
Added lines #L421 - L422 were not covered by tests
[warning] 424-425: torch_sim/integrators/md.py#L424-L425
Added lines #L424 - L425 were not covered by tests
[warning] 427-427: torch_sim/integrators/md.py#L427
Added line #L427 was not covered by tests
[warning] 436-436: torch_sim/integrators/md.py#L436
Added line #L436 was not covered by tests
torch_sim/integrators/npt.py
[warning] 864-864: torch_sim/integrators/npt.py#L864
Added line #L864 was not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms (225)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: build-docs
🔇 Additional comments (133)
examples/readme.md (1)
6-6
: LGTM: Improved link text clarity.This change makes the issue reference more descriptive and informative for readers.
.pre-commit-config.yaml (1)
10-10
: LGTM: Routine dependency updates.The version bumps for ruff-pre-commit and markdownlint-cli are appropriate minor/patch updates that should include bug fixes and improvements.
Also applies to: 35-35
torch_sim/models/morse.py (3)
37-40
: Well-structured default parameters.The default parameter constants are clearly defined and appropriately typed.
42-77
: Excellent implementation of the Morse potential.The mathematical implementation is correct:
- Formula matches the standard Morse potential: V(r) = ε(1 - e^(-α(r-σ)))² - ε
- Proper edge case handling for dr ≤ 0
- Comprehensive docstring with clear parameter descriptions
The
torch.where
approach for handling numerical instabilities is more explicit than the commentedtorch.nan_to_num
alternative.
79-113
: Correct implementation of Morse force calculation.The force derivation is mathematically sound:
- Properly implements F(r) = -dV/dr = 2αε·e^(-α(r-σ))·(1 - e^(-α(r-σ)))
- Consistent edge case handling with the potential function
- Clear documentation explaining the relationship to the potential energy
.github/workflows/docs.yml (1)
37-37
: LGTM: Consistent with metatensor → metatomic rename.This change aligns with the systematic renaming from "metatensor" to "metatomic" across the repository and ensures the docs build process uses the correct extra dependencies.
examples/scripts/7_Others/7.2_Stress_autograd.py (1)
20-20
: LGTM: Clean import path update.The import correctly moves from the removed unbatched module to the consolidated
torch_sim.models.lennard_jones
. The functionslennard_jones_pair
andlennard_jones_pair_force
maintain the same signatures and behavior.docs/tutorials/index.rst (1)
22-22
: LGTM: Documentation update reflects the metatensor→metatomic transition.The toctree reference correctly updates to match the new tutorial naming convention.
tests/test_integrators.py (1)
379-379
: LGTM: Import path correctly updated for reorganized integrators.The import correctly moves
_compute_cell_force
from the main integrators module to the specificnpt
submodule, aligning with the integrator reorganization.examples/scripts/7_Others/7.1_Soft_sphere_autograd.py (1)
15-15
: LGTM: Clean import path update for soft sphere functions.The import correctly moves from the removed unbatched module to the consolidated
torch_sim.models.soft_sphere
. The functions maintain the same signatures and behavior.tests/models/test_morse.py (1)
6-6
: LGTM: Import consolidation correctly updates path and simplifies imports.The consolidated import correctly moves all Morse model components from the removed unbatched module to
torch_sim.models.morse
, maintaining consistent functionality.README.md (1)
17-17
: Corrected dependency name to metatomic.
This aligns the supported MLIP models with the updated package naming.examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py (1)
38-39
: Centralize CI check withSMOKE_TEST
.
Extracting the CI flag intoSMOKE_TEST
improves readability and maintainability.docs/conf.py (1)
69-69
: Updatedautodoc_mock_imports
to includemetatomic
.
This ensures Sphinx mocks the renamed dependency correctly.examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py (1)
39-40
: UseSMOKE_TEST
for CI-based step count.
Consolidating the CI-environment check intoSMOKE_TEST
makes the script more readable and consistent.examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py (3)
20-20
: LGTM: Clean import alias for better readability.The alias from
MetalUnits
toUnits
improves code readability without changing functionality.
26-26
: Consistent with import alias change.The update to use
Units.temperature
correctly follows the import alias introduced on line 20.
43-44
: Good addition of explicit computation control.The new
compute_forces=True
andcompute_stress=False
arguments make the model's computational behavior explicit, which is appropriate for a dynamics simulation where forces are required..github/workflows/test.yml (2)
76-76
: Consistent with the metatensor to metatomic renaming.The matrix entry change correctly reflects the updated model name and aligns with the test file rename.
52-52
:✅ Verification successful
Verify the renamed test file exists.
The change from
test_metatensor.py
totest_metatomic.py
aligns with the broader renaming effort. Ensure the target file exists to avoid workflow issues.
🏁 Script executed:
#!/bin/bash # Verify the metatomic test file exists fd "test_metatomic.py" tests/Length of output: 60
Ignore pattern matches actual test file
.github/workflows/test.yml: line 52
now ignorestests/models/test_metatomic.py
- Verified that
tests/models/test_metatomic.py
exists in the repositoryApproved.
examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py (1)
38-39
: Excellent refactoring for improved readability.The introduction of the
SMOKE_TEST
boolean variable makes the CI detection explicit and improves code readability. The logic correctly reduces optimization steps for CI environments.examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py (1)
38-39
: Consistent refactoring pattern across scripts.The
SMOKE_TEST
boolean introduction follows the same clean pattern as other scripts in this PR, demonstrating good consistency in the refactoring effort. The logic correctly handles CI environment detection.examples/scripts/5_Workflow/5.2_In_Flight_WBM.py (2)
23-24
: LGTM - Clean refactoring to centralize environment check.The introduction of the
SMOKE_TEST
variable improves code readability and maintainability by centralizing the CI environment check.
40-41
: Consistent usage of SMOKE_TEST variable.The refactoring correctly replaces all inline
os.getenv("CI")
checks with the centralizedSMOKE_TEST
variable, maintaining the same functional behavior while improving consistency.Also applies to: 75-75
pyproject.toml (2)
35-36
:✅ Verification successful
Verify the updated vesin dependency versions.
The versions for
vesin-torch
andvesin
have been updated from ">=0.3.1" to ">=0.3.7". Please verify that these versions are available and don't introduce breaking changes or security vulnerabilities.
🏁 Script executed:
#!/bin/bash # Description: Check PyPI for the latest versions of vesin packages and any security advisories echo "Checking vesin-torch versions..." curl -s https://pypi.org/pypi/vesin-torch/json | jq '.releases | keys | .[-5:]' echo "Checking vesin versions..." curl -s https://pypi.org/pypi/vesin/json | jq '.releases | keys | .[-5:]' # Check for security advisories echo "Checking for security advisories..." gh api graphql -f query=' { securityVulnerabilities(first: 5, ecosystem: PIP, package: "vesin-torch") { nodes { advisory { summary severity publishedAt } vulnerableVersionRange } } }'Length of output: 782
Verified vesin package updates
Both
vesin-torch>=0.3.7
andvesin>=0.3.7
exist on PyPI (latest releases include 0.3.7), and no recent security advisories were found for these packages. The dependency bumps are safe to merge.
51-51
:✅ Verification successful
Verify the metatomic dependency versions.
The optional dependency group has been renamed from "metatensor" to "metatomic" with updated package versions. Please verify these packages exist and are compatible.
🏁 Script executed:
#!/bin/bash # Description: Check PyPI for metatomic-torch availability and metatrain version echo "Checking metatomic-torch versions..." curl -s https://pypi.org/pypi/metatomic-torch/json | jq '.info.version, .releases | keys | .[-3:]' echo "Checking metatrain versions..." curl -s https://pypi.org/pypi/metatrain/json | jq '.releases | keys | .[-3:]'Length of output: 428
Verified metatomic dependencies on PyPI
- pyproject.toml (line 51):
metatomic-torch>=0.1.1,<0.2
– latest release is 0.1.2 (within the specified range)- pyproject.toml (line 51):
metatrain[pet]==2025.7
– version 2025.7 is published with thepet
extraNo changes required.
examples/scripts/4_High_level_api/4.1_high_level_api.py (2)
30-30
: LGTM - Consistent refactoring pattern.The introduction of
SMOKE_TEST
follows the same clean pattern as other files in this PR, centralizing the CI environment check.
45-45
: Consistent usage across all integration and optimization calls.All instances of inline
os.getenv("CI")
checks have been properly replaced with theSMOKE_TEST
variable, maintaining consistent behavior while improving readability.Also applies to: 72-72, 112-112, 130-130, 152-152, 171-171, 187-187, 208-208
tests/test_runners.py (2)
300-300
: LGTM - Boundary condition adjustment for batched optimizer compatibility.The change from
> 10
to>= 10
makes the test more flexible, which is appropriate when transitioning from unbatched to batched optimizers that may have slightly different convergence behavior.
330-331
: LGTM - Improved test clarity with descriptive comment.The comment clarification helps explain that this assertion verifies overall energy decrease during optimization, which is the expected behavior.
examples/scripts/4_High_level_api/4.2_auto_batching_api.py (1)
36-36
: Excellent refactoring - centralized environment detection improves maintainability.The introduction of the
SMOKE_TEST
variable centralizes CI environment detection and eliminates repeated inlineos.getenv("CI")
checks throughout the script. This improves code readability and makes it easier to modify test behavior.Also applies to: 59-59, 73-73, 103-103, 113-113
examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py (3)
15-16
: LGTM - Proper migration to batched model and integrator.The import changes from unbatched to batched versions (
nve
integrator andLennardJonesModel
) correctly align with the PR objective of removing unbatched code.
26-27
: LGTM - Consistent SMOKE_TEST pattern for CI environments.The introduction of the
SMOKE_TEST
variable follows the same pattern used across other example scripts, providing consistent CI-aware step control.
92-101
: LGTM - Successful transition to batched LennardJonesModel.The model instantiation has been correctly updated to use the batched
LennardJonesModel
with the same parameters, maintaining simulation consistency while supporting the new batched architecture.torch_sim/state.py (3)
109-110
: LGTM - Automatic batch dimension handling for cell tensor.The automatic addition of a batch dimension when
cell.ndim != 3
andbatch is None
provides backward compatibility while ensuring consistent batched state handling.
112-113
: LGTM - Essential validation for batched cell tensor format.The shape validation ensures the cell tensor conforms to the expected
(n_batches, 3, 3)
format, which is critical for the batched model architecture.
123-126
: LGTM - Comprehensive batch dimension consistency check.The validation ensures the cell's first dimension matches
n_batches
, preventing shape mismatches that could cause runtime errors in batched operations. The descriptive error message aids debugging.examples/scripts/3_Dynamics/3.2_MACE_NVE.py (5)
17-18
: LGTM! Import updates align with batched API transition.The imports have been correctly updated to use the new modular integrator structure and batched MACE model.
40-41
: Good refactor to improve code clarity.The environment variable check has been refactored into a clear boolean
SMOKE_TEST
flag, making the code more readable and maintainable.
52-60
: LGTM! Model initialization correctly uses batched MaceModel.The model initialization has been properly updated to use the batched
MaceModel
class with appropriate parameters.
74-74
: Good practice: Explicit tensor creation with device and dtype.Creating the temperature tensor with explicit device and dtype ensures consistency and prevents potential device/dtype mismatches.
90-92
: Correct batch parameter addition for kinetic energy calculation.The
batch
parameter has been correctly added to the kinetic energy calculation to handle batched simulation states properly.examples/tutorials/metatomic_tutorial.py (3)
29-31
: LGTM! Correctly updated to use MetatomicModel.The import and model instantiation have been properly updated to use the new metatomic package interface.
15-18
: Tutorial text consistently updated for metatomic.The tutorial text has been appropriately updated to reference metatomic instead of metatensor throughout.
Also applies to: 71-71
6-7
:❓ Verification inconclusive
Verify package version compatibility.
The dependencies have been updated to newer versions. Please ensure these versions are compatible and available.
🌐 Web query:
Are metatrain version 2025.7 and metatomic-torch version 0.1.1 available and compatible?
💡 Result:
Based on the search results provided, there is no specific information about metatrain version 2025.7 or metatomic-torch version 0.1.1 being available or compatible with each other.
The search results do mention metatrain as a command line interface for training and evaluating atomistic models[3], but they don't reference the specific version 2025.7. The most recent information about metatrain shows it can be installed via pip or conda, but no version numbers are specified[3].
Similarly, there is no mention of a package called "metatomic-torch" in the search results. There is information about a package called "metatensor-torch" with version 0.6.0[6], but nothing about metatomic-torch 0.1.1.
Without specific information about these versions and their compatibility, I cannot confirm whether metatrain version 2025.7 and metatomic-torch version 0.1.1 are available or compatible. If you're looking to use these specific packages together, you might want to check their official documentation, GitHub repositories, or package indices like PyPI for the most up-to-date information on availability and compatibility.
Citations:
- 1: https://www.meta.com/blog/connect-2025-llamacon-save-the-date/
- 2: https://www.polaranalytics.com/post/2025-metas-tracking-restrictions-for-health-wellness-are-here----heres-how-to-fix-it
- 3: https://github.com/lab-cosmo/metatensor-models
- 4: https://investor.atmeta.com/investor-news/press-release-details/2025/Meta-Reports-First-Quarter-2025-Results/default.aspx
- 5: https://www.pymnts.com/artificial-intelligence-2/2025/meta-adds-multimodal-models-to-its-llama-ai-stable/
- 6: https://pypi.org/project/metatensor-torch/
- 7: https://pytorch.org/docs/stable/meta.html
- 8: https://www.felixsanz.dev/articles/compatibility-between-pytorch-cuda-and-xformers-versions
- 9: https://www.meta.com/connect/
- 10: https://repository.si.edu/bitstream/handle/10088/23971/1931%20SMC%20V88%20%20Smiths%20Physical%20Tables%201-682.pdf?sequence=1&isAllowed=y
Verify availability of metatrain 2025.7 and metatomic-torch 0.1.1
It isn’t clear from public sources whether these exact versions are published or compatible. Please confirm:
- That
metatrain==2025.7
is available on PyPI (or your chosen registry)- That
metatomic-torch>=0.1.1,<0.2
exists and supports the required dependenciestests/models/test_lennard_jones.py (3)
9-11
: LGTM! Import correctly updated to batched LennardJonesModel.The import has been properly updated to use the batched
LennardJonesModel
from the main models module instead of the unbatched version.
162-165
: LGTM! Model instantiation correctly uses batched LennardJonesModel.Both neighbor list and direct calculation models have been properly updated to use the batched
LennardJonesModel
class while maintaining all the test parameters.
234-234
: LGTM! Type annotation correctly updated.The function parameter type annotation has been properly updated to reflect the batched
LennardJonesModel
class.examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py (5)
15-16
: LGTM! Imports correctly updated to use batched models and optimizers.The imports have been properly updated to use the batched
SoftSphereModel
andfire
optimizer from the main modules instead of the unbatched versions.
28-29
: Good refactor for improved code clarity.The environment variable check has been refactored into a clear boolean
SMOKE_TEST
flag, consistent with other example scripts.
79-85
: Correct handling of batch dimension in SimState creation.The cell tensor is properly unsqueezed to add a batch dimension before passing to
SimState
, which is required for the batched API.
88-95
: LGTM! Model initialization correctly uses batched SoftSphereModel.The model initialization has been properly updated to use the batched
SoftSphereModel
class with appropriate parameters.
112-112
: Correct indexing for batched energy and force access.The energy and force values are correctly accessed using
[0]
indexing to get the first batch element, which is appropriate for the batched state structure.Also applies to: 116-119
examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py (4)
17-18
: LGTM! Clean migration to batched integrators and models.The import changes correctly replace unbatched implementations with their batched counterparts from the new
torch_sim.integrators
module.
40-41
: Good refactoring of environment variable handling.The
SMOKE_TEST
boolean provides cleaner, more readable control flow compared to inline environment variable checks.
61-63
: Excellent explicit tensor creation.Creating the temperature tensor with explicit device and dtype parameters ensures proper tensor placement and prevents potential device/dtype mismatches in batched computations.
80-80
: Correct addition of batch argument.The
batch=state.batch
argument is properly added to handle batched kinetic energy calculations, aligning with the new batched API requirements.examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py (3)
16-17
: Consistent migration to batched APIs.The imports correctly transition from unbatched to batched implementations, following the same pattern across dynamics scripts.
70-70
: Proper tensor creation with device/dtype specification.Explicit device and dtype parameters ensure the temperature tensor is correctly placed and typed for batched computations.
85-89
: Correct batch argument handling in temperature calculations.The addition of
batch=state.batch
tocalc_kT
calls properly handles batched states, and the.item()
calls correctly extract scalar values for printing.Also applies to: 92-96
torch_sim/math.py (1)
992-1022
: Well-implemented batched utility function.The
batched_vdot
function provides efficient batched vector dot product computation with:
- Comprehensive input validation covering shape mismatches and negative indices
- Efficient implementation using
scatter_add_
for batch accumulation- Clear documentation with explicit shape specifications
- Proper error handling with informative messages
This utility function nicely supports the broader transition to batched operations throughout the codebase.
examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py (4)
16-17
: Consistent import migration to batched APIs.The imports follow the established pattern of transitioning to batched integrators and models.
54-54
: Good use of utility function for state initialization.Using
ts.io.atoms_to_state
provides a standardized way to convert ASE atoms to simulation state, replacing manual tensor construction.
60-62
: Proper temperature tensor initialization.Explicit device and dtype specification ensures correct tensor placement for batched operations.
54-54
:❌ Incorrect review comment
Investigate pipeline failure with cell tensor shapes.
The pipeline failure indicates a shape mismatch in neighbor list computation: "box must be 3 x 3 tensor, but the shape is [64, 3]". This suggests the cell tensor is incorrectly shaped for batched neighbor list operations.
🏁 Script executed:
#!/bin/bash # Description: Search for cell tensor handling in neighbor list functions and MaceModel # Expected: Find how cell tensors are processed in batched operations echo "=== Searching for cell tensor usage in neighbor list functions ===" rg -A 5 -B 5 "box.*3.*3" --type py echo "=== Searching for row_vector_cell usage in MaceModel ===" rg -A 10 -B 5 "row_vector_cell" --type py echo "=== Searching for neighbor_list_fn usage ===" ast-grep --pattern 'neighbor_list_fn($$$)'Length of output: 46097
I need to inspect the example script to confirm how
atoms_to_state
constructs the cell tensor and whether it’s passed correctly into the MACE model’s neighbor-list call. Let me pull the relevant lines around theatoms_to_state
invocation.
🏁 Script executed:
#!/bin/bash # Locate and dump the first 200 lines of the example script to review cell handling fd -t f "3.5_MACE_NVT_Nose_Hoover.py" --print0 | xargs -0 sed -n '1,200p'Length of output: 2375
To confirm how the ASE cell is converted into the SimState (and why it may become shape
[64, 3]
), let’s inspect theatoms_to_state
implementation.
🏁 Script executed:
#!/bin/bash # Find and display the definition of atoms_to_state rg -n "def atoms_to_state" -R torch_simLength of output: 446
We need to locate the actual implementation of
atoms_to_state
. Let’s search across thetorch_sim/io
module for its definition:
🏁 Script executed:
#!/bin/bash # Search for the atoms_to_state definition in torch_sim/io rg -n "def atoms_to_state" -g "torch_sim/io/*.py"Length of output: 200
Let’s locate the
atoms_to_state
definition and see howneighbor_list_fn
is bound inMaceModel
:
🏁 Script executed:
#!/bin/bash echo "=== Find atoms_to_state occurrences ===" rg -n "atoms_to_state" echo "=== Inspect neighbor_list_fn in MaceModel ===" rg -n "self.neighbor_list_fn" -n torch_sim/models/mace.pyLength of output: 7245
I’ll inspect
SimState
to confirm how the cell tensor is stored and sliced:
🏁 Script executed:
#!/bin/bash # Locate SimState class definition rg -n "class SimState" -n torch_sim/state.py # Dump its cell handling portion sed -n '800,880p' torch_sim/state.py sed -n '880,960p' torch_sim/state.pyLength of output: 4600
Inspecting the
atoms_to_state
implementation to see how it constructs the cell tensor:
🏁 Script executed:
#!/bin/bash # Dump the atoms_to_state function from torch_sim/io.py sed -n '180,300p' torch_sim/io.pyLength of output: 4059
Let’s inspect how
neighbor_list_fn
is selected and bound inMaceModel
:
🏁 Script executed:
#!/bin/bash # Dump lines 100–200 of MaceModel to see neighbor_list_fn setup sed -n '100,200p' torch_sim/models/mace.pyLength of output: 4536
Ignore the cell‐shape issue in this example script
The call
state = ts.io.atoms_to_state(si_dc, device=device, dtype=dtype)correctly produces a cell tensor of shape
[1,3,3]
(one batch), andMaceModel
usesstate.row_vector_cell[b] # → [3,3]for each neighbor‐list call. The reported error
box must be 3 x 3 tensor, but the shape is [64, 3]
actually stems from how
MaceModel.forward
passes the batched cell into the underlying MACE model—not from this example oratoms_to_state
. No changes are required in3.5_MACE_NVT_Nose_Hoover.py
.Likely an incorrect or invalid review comment.
examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py (5)
27-34
: LGTM! Import updates align with batched model transition.The import changes correctly transition from unbatched to batched model interfaces, consistent with the broader codebase refactoring.
55-55
: Good practice: Centralized CI test configuration.Using a boolean
SMOKE_TEST
variable improves readability compared to inline environment variable checks.
75-82
: Model initialization correctly updated for batched interface.The transition from
UnbatchedMaceModel
toMaceModel
is appropriate and consistent with the codebase changes.
220-225
:✅ Verification successful
Verify function name change consistency.
The change from
get_unit_cell_relaxed_structure_batched
toget_unit_cell_relaxed_structure
appears inconsistent with using a batched model. Please ensure this function properly handles batched inputs.
🏁 Script executed:
#!/bin/bash # Description: Verify the function signature and usage to ensure it handles batched models correctly # Expected: Function should accept batched states and models # Search for the function definition and usage ast-grep --pattern 'def get_unit_cell_relaxed_structure($$$)' echo "--- Function usage patterns ---" rg -A 3 -B 3 'get_unit_cell_relaxed_structure'Length of output: 6793
Batched input support confirmed in get_unit_cell_relaxed_structure
I’ve verified that the renamed function correctly handles batched inputs:
- In torch_sim/workflows/a2c.py (lines 704–724), it initializes
logger
withstate.n_batches
and drives optimization viabatched_unit_cell_fire
, which returns batch-aware init/update functions.- examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py invokes it on a multi‐structure
batch_state
.- tests/workflows/test_a2c.py cover the single‐batch case, but the implementation universally supports any
state.n_batches
.No further changes are needed here.
94-102
:❌ Incorrect review comment
Investigate pipeline failure related to tensor shapes.
The pipeline failure indicates a TorchScript error: "
box
must be 3 x 3 tensor, but the shape is [64, 3]". This suggests a batching issue where the cell tensor has incorrect dimensions. Please verify that batched cell tensors are properly handled.
🏁 Script executed:
#!/bin/bash # Description: Search for potential cell/box tensor shape issues in batched processing # Expected: Find where cell tensors might be incorrectly shaped for TorchScript # Look for cell tensor usage in A2C and batching code rg -A 5 -B 5 'cell.*shape\|box.*shape\|TorchScript.*cell\|TorchScript.*box' # Search for cell tensor operations in the A2C workflow ast-grep --pattern 'cell = $$$' # Look for batched cell handling rg -A 3 'cell.*batch\|batch.*cell'Length of output: 7199
🏁 Script executed:
#!/bin/bash # Display the first 200 lines of the A2C Silicon batched example to inspect cell batching sed -n '1,200p' examples/scripts/5_Workflow/5.1_a2c_silicon_batched.pyLength of output: 6125
🏁 Script executed:
#!/bin/bash # Inspect A2C workflow for subcell handling and cell tensor shaping rg -n "def get_subcells_to_crystallize" -A 200 -B 5 torch_sim/workflows/a2c.py rg -n "def subcells_to_structures" -A 200 -B 5 torch_sim/workflows/a2c.pyLength of output: 17387
🏁 Script executed:
#!/bin/bash # Inspect random_packed_structure implementation for batching logic rg -n "def random_packed_structure" -A200 -B5 torch_sim/workflows/a2c.pyLength of output: 15135
No batched cell tensor here
The example uses a fixed 3×3cell
(e.g.cell = torch.tensor([[11.1, 0, 0], [ 0, 11.1, 0], [ 0, 0, 11.1]], dtype=dtype, device=device)) and never applies a batch dimension to it. The
batch_size
parameter governs how many structures are relaxed in parallel, not the shape ofcell
. The TorchScript “box must be 3×3 tensor, but shape is [64,3]” error must originate elsewhere.Likely an incorrect or invalid review comment.
examples/scripts/1_Introduction/1.2_MACE.py (2)
1-1
: Excellent refactor to demonstrate batched MACE usage.The updated description accurately reflects the new batched functionality.
35-89
: Well-implemented batched tensor construction.The code correctly demonstrates:
- Concatenating positions from multiple structures
- Stacking cell tensors with proper batch dimension
- Creating batch indices using
repeat_interleave
- Proper tensor device and dtype handling
This is an excellent educational example of batched model usage.
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 82-88: Consider using '{"positions": positions, "cell": cell, "atomic_numbers": atomic_numbers, ... }' instead of a call to 'dict'.
(R1735)
examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py (5)
17-18
: LGTM! Import updates align with batched model transition.The updated imports correctly use the general integrators module and batched MACE model.
40-41
: Good practice: Centralized CI test configuration.The
SMOKE_TEST
pattern improves code clarity and consistency across examples.
55-55
: Excellent simplification using utility function.Using
ts.io.atoms_to_state
is much cleaner than manual tensor construction and handles all the required conversions automatically.
74-74
: Correctly added batch argument for batched processing.The addition of the
batch
argument tocalc_kinetic_energy
is necessary for proper batched tensor handling.
53-53
:❓ Verification inconclusive
Address pipeline failure: Model loading incompatibility.
The pipeline shows a critical model loading error with state_dict mismatches. The
enable_cueq=True
setting may be causing compatibility issues with certain model checkpoints.Consider making CuEq conditional or handling the model loading error gracefully:
- enable_cueq=True, + enable_cueq=torch.cuda.is_available(),
🏁 Script executed:
#!/bin/bash # Description: Check for CuEq compatibility issues and model loading patterns # Expected: Find how other scripts handle CuEq and model loading errors # Search for CuEq usage patterns rg -A 3 -B 3 'enable_cueq.*True\|enable_cueq.*False' # Look for model loading error handling rg -A 5 'state_dict.*mismatch\|Unexpected.*keys'Length of output: 105
#!/bin/bash
set -e1. Find all occurrences of enable_cueq in the repo
rg --no-heading --line-number 'enable_cueq'
2. Show context around enable_cueq in the specific script
rg --no-heading --line-number -C3 'enable_cueq' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
3. Check for model loading calls and state_dict handling in the script
rg --no-heading --line-number -C3 'load_state_dict' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
rg --no-heading --line-number -C3 'torch.load' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py4. Look for any exception handling around model loading
rg --no-heading --line-number -C3 'try:' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
#!/bin/bash
set -e1. Find all enable_cueq occurrences in the repo
rg --no-heading --line-number 'enable_cueq' -C3 .
2. Show context around enable_cueq in the specific script
rg --no-heading --line-number 'enable_cueq' -C3 examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
3. Compare with other scripts to see default handling
rg --no-heading --line-number 'enable_cueq' -C3 examples/scripts
4. Inspect model loading and error handling in the CuEq script
rg --no-heading --line-number -e 'load_state_dict' -e 'torch.load' -e 'except' -C3 examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
#!/bin/bash
set -e1. Locate all occurrences of enable_cueq across the repo
rg --no-heading --line-number 'enable_cueq' -C3 .
2. Show full context around the flag in the CuEq script
rg --no-heading --line-number -C5 'enable_cueq' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
3. Inspect model‐loading logic and error handling in that script
rg --no-heading --line-number -C5 'torch.load|load_state_dict|checkpoint' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py
rg --no-heading --line-number -C5 'try:|except' examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py4. Compare CuEq defaults in other example scripts
rg --no-heading --line-number -C3 'enable_cueq' examples/scripts
5. Check for documented CuEq compatibility notes in README or docs
rg --files-with-matches 'cueq' -e 'enable_cueq' -e 'CuEq' -t md -t rst .
Verify CuEq flag usage and model loading error handling
I wasn’t able to find any existing conditional logic around
enable_cueq
or try/except blocks for model loading in this script or others—please double-check before merging. To guard against state_dict mismatches on devices without CuEq support, consider:
- Making CuEq conditional on CUDA availability
- Wrapping
load_state_dict
in a try/except to fall back gracefullySuggested diff:
- enable_cueq=True, + enable_cueq=torch.cuda.is_available(),And around model loading:
try: model.load_state_dict(checkpoint['state_dict']) except RuntimeError as e: logger.warning(f"Model loading failed: {e}. Disabling CuEq and retrying.") enable_cueq = False # reload or alternate logic herePlease verify this approach against your pipeline’s model checkpoints and ensure it resolves the incompatibility.
torch_sim/quantities.py (1)
120-123
: Improved comment clarity.The change from inline to trailing comments makes the conditional logic clearer and more readable.
examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py (5)
15-16
: Successfully migrated to batched models.The import changes from unbatched to batched implementations align perfectly with the PR objectives.
28-29
: Good centralization of CI environment detection.Using a
SMOKE_TEST
variable improves code readability and makes the CI-specific behavior more explicit.
91-91
: Correct batch dimension handling.Adding the batch dimension to the cell tensor with
unsqueeze(0)
properly adapts the script to the batched state representation.
111-118
: Proper indexing for batched outputs.The indexing with
[0]
correctly extracts values from the first batch element, handling the transition to batched tensor outputs.
76-85
:✅ Verification successful
Verify parameter compatibility between unbatched and batched models.
The model parameters appear identical, but it's worth confirming that the batched
LennardJonesModel
handles these parameters the same way as the unbatched version.
🏁 Script executed:
#!/bin/bash # Description: Verify LennardJonesModel parameter compatibility and usage patterns # Search for LennardJonesModel initialization patterns to compare parameters rg -A 10 "LennardJonesModel\(" --type pyLength of output: 14146
No parameter compatibility issues detected
Initialization of LennardJonesModel with and without neighbor lists uses the same keyword arguments, and the existing test tests/models/test_lennard_jones.py (
test_energy_match
) explicitly verifies that both direct and neighbor-list modes produce identical results. No changes are needed.
- tests/models/test_lennard_jones.py:
test_energy_match
confirms parameter handling is consistent.torch_sim/integrators/__init__.py (1)
1-28
: Well-structured integrator module with clear API.This module provides an excellent entry point for the integrator functionality with:
- Comprehensive documentation and usage examples
- Logical organization of imports from submodules
- Proper handling of linter warnings for re-exports
The modular design supports the transition from unbatched to batched integrators effectively.
tests/workflows/test_a2c.py (3)
6-7
: Consistent migration to batched model imports.The import changes from
UnbatchedSoftSphereModel
toSoftSphereModel
andUnitCellFIREState
toUnitCellFireState
align with the codebase-wide removal of unbatched implementations.
333-342
: Correct model type annotation and instantiation.The function now properly returns the batched
SoftSphereModel
with the same parameters as the unbatched version, maintaining test functionality.
377-380
: Proper adaptation to batched outputs.The type assertion and tensor indexing changes correctly handle the new batched state class and tensor outputs, ensuring tests continue to validate the expected behavior.
examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py (5)
16-19
: LGTM: Import updates align with batched API migration.The imports have been correctly updated to use the batched integrators and models from their respective modules.
43-50
: Model initialization updated correctly for batched API.The migration from
UnbatchedMaceModel
toMaceModel
is correct, and usingts.io.atoms_to_state
simplifies state creation by handling tensor conversions automatically.
71-76
: Temperature calculations correctly updated for batched states.Adding the
batch=state.batch
argument ensures proper temperature calculation for batched systems, and using.item()
for scalar extraction in print statements is correct.
86-89
: Consistent batched quantity calculations throughout simulation loop.All temperature, kinetic energy, and pressure calculations correctly include the batch argument and use
.item()
for scalar printing. The implementation is consistent with the batched API.Also applies to: 94-99, 103-107
109-114
: Final calculations properly handle batched states.The final temperature and pressure calculations maintain consistency with the batched approach used throughout the script.
Also applies to: 119-121
examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py (4)
22-24
: Import updates correctly migrate to batched API.The imports properly reference the batched
MaceModel
and NVT integrators.
99-104
: SMOKE_TEST boolean improves CI detection clarity.The boolean assignment is cleaner than inline environment variable checks and makes the conditional logic more readable.
140-140
: Temperature tensor creation improved with explicit device/dtype.Creating the temperature tensor with explicit device and dtype parameters ensures consistency with the simulation state.
165-174
: Temperature calculations correctly handle batched states.The addition of
batch=state.batch
tocalc_kT
and the use of.item()
for printing are correct for the batched API.examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py (5)
17-19
: Model and utility imports correctly updated.The
MaceModel
import and quantities imports are properly updated for the batched API.
42-50
: Model initialization follows batched API pattern.The model and state initialization are consistent with other updated scripts, though the same potential cell tensor issue may affect this script.
55-61
: CI detection and temperature handling improved.The
SMOKE_TEST
boolean and explicit temperature tensor creation are consistent improvements.
69-76
: Temperature and pressure calculations properly handle batched states.All quantity calculations correctly include the batch argument and use
.item()
for scalar extraction in print statements.Also applies to: 86-89, 97-99, 104-107
110-122
: Final calculations maintain batched API consistency.The final temperature and pressure calculations follow the same correct pattern used throughout the script.
torch_sim/models/lennard_jones.py (2)
40-77
: Lennard-Jones potential implementation is mathematically correct.The function correctly implements the 12-6 Lennard-Jones potential:
V(r) = 4*epsilon*[(sigma/r)^12 - (sigma/r)^6]
. The edge case handling withtorch.where(dr > 0, energy, torch.zeros_like(energy))
prevents numerical instabilities from zero distances.
79-120
: Lennard-Jones force calculation correctly implements the negative gradient.The force implementation
F(r) = 24*epsilon/r * [2*(sigma/r)^12 - (sigma/r)^6]
is the correct negative gradient of the potential. The same edge case protection prevents division by zero.examples/scripts/1_Introduction/1.1_Lennard_Jones.py (1)
63-85
: LGTM! Correct transition to batched Lennard-Jones model.The changes properly replace the unbatched model with the batched
LennardJonesModel
and correctly add the batch dimension to the cell tensor withcell.unsqueeze(0)
. This aligns perfectly with the PR objective of removing unbatched code while maintaining functionality.🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 77-82: Consider using '{"positions": positions, "cell": cell.unsqueeze(0), "atomic_numbers": atomic_numbers, ... }' instead of a call to 'dict'.
(R1735)
torch_sim/runners.py (4)
19-24
: LGTM! Proper imports for new batched optimizer states.These imports support the transition from unbatched to batched optimizers, aligning with the PR objective.
287-317
: LGTM! Well-implemented cell forces convergence checking.The addition of
include_cell_forces
parameter enhances the convergence function to optionally check cell forces alongside regular forces. The implementation correctly:
- Retrieves cell forces from the state with proper error handling
- Computes the norm across appropriate dimensions
- Combines cell force and regular force convergence with logical AND
This is a valuable enhancement for cell optimization scenarios.
401-410
: LGTM! Appropriate conditional initialization for stateful optimizers.The logic correctly skips
_chunked_apply
initialization for the new batched optimizer state classes that handle their own initialization. This prevents double initialization and supports the transition to stateful optimizers.
540-562
: LGTM! Improved variable naming consistency.The renaming from
substate
tosub_state
improves code consistency and readability.examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py (2)
33-33
: LGTM! Cleaner CI environment detection.The
SMOKE_TEST
boolean consolidates repeated environment variable checks, improving code readability and maintainability.
175-177
: LGTM! Proper integration with enhanced convergence checking.The convergence function correctly passes
include_cell_forces=ts_use_frechet
to enable cell force convergence checking when Frechet cell optimization is used. This leverages the new functionality added togenerate_force_convergence_fn
.tests/models/test_metatomic.py (3)
11-17
: LGTM! Robust dependency handling.The try-except block with
pytest.skip
gracefully handles missing metatomic dependencies, ensuring tests don't fail in environments where the package isn't installed.
20-44
: LGTM! Well-structured test fixtures.The fixtures for dtype, calculator, and model are properly defined with appropriate parameters and return types. The model instantiation uses sensible defaults.
46-66
: LGTM! Comprehensive test coverage.The tests appropriately cover:
- Model initialization verification
- Consistency between model and calculator outputs
- Model output validation
This provides adequate coverage for the metatomic model integration and follows the established testing patterns.
examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py (8)
15-17
: LGTM! Import updates align with batched API migration.The imports correctly reflect the transition from unbatched to batched modules as described in the PR objectives.
37-38
: LGTM! Good CI optimization practice.The SMOKE_TEST flag appropriately reduces simulation complexity in CI environments while maintaining test coverage.
85-94
: LGTM! Correct batched model initialization.The LennardJonesModel initialization properly uses the batched API with appropriate parameters.
107-109
: LGTM! Proper tensor creation with device/dtype consistency.The target_pressure tensor is correctly created with explicit device and dtype, ensuring consistency with other tensors.
125-128
: LGTM! Correct batched function calls.The addition of
batch=state.batch
arguments to calc_kT and calc_kinetic_energy is required for the batched API, and the.item()
calls properly extract scalar values.Also applies to: 134-136, 149-152
140-140
: LGTM! Proper batched tensor indexing.The ellipsis indexing
state.cell[..., i, i]
correctly handles batched cell tensors.
142-143
: LGTM! Necessary scalar extraction from tensors.The
.item()
calls correctly extract scalar values from tensors returned by the batched API.Also applies to: 153-153
98-98
:✅ Verification successful
Verify cell tensor handling with batched models.
While the
unsqueeze(0)
adds the required batch dimension, the pipeline failure suggests a cell shape issue. The error indicates the model expects either per-atom (256) or per-batch (1) dimensions but receives [3, 3].
🏁 Script executed:
#!/bin/bash # Check how cell tensors are handled in the batched LennardJonesModel ast-grep --pattern $'cell.unsqueeze($_)' echo "---" ast-grep --pattern $'class LennardJonesModel { $$$ def forward($$$) { $$$ } $$$ }'Length of output: 624
🏁 Script executed:
#!/bin/bash # Inspect imports and usage of LennardJonesModel in the script echo "=== Header (first 200 lines) ===" sed -n '1,200p' examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py echo echo "=== Search for LennardJonesModel usage ===" rg "LennardJonesModel" -n examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.pyLength of output: 5319
No action required —
cell.unsqueeze(0)
is correctThe
unsqueeze(0)
call correctly adds a batch dimension (yielding[1, 3, 3]
) and matches its usage in all Lennard-Jones examples. You can ignore the earlier concern about a cell-shape mismatch.torch_sim/models/soft_sphere.py (2)
59-94
: LGTM! Well-implemented soft sphere potential with proper safeguards.The function correctly implements the soft sphere potential V = ε/α(1 - r/σ)^α with finite cutoff at σ. The use of
transforms.safe_mask
ensures numerical stability for masked values.
96-131
: LGTM! Correct force implementation as derivative of potential.The function properly implements the force as F = -(ε/σ)(1 - r/σ)^(α-1), which is the correct negative gradient of the potential. Consistent use of
safe_mask
ensures stability.examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (5)
15-17
: LGTM! Consistent batched API imports.The imports correctly use the batched modules, consistent with the codebase migration.
38-39
: LGTM! Consistent CI optimization.The SMOKE_TEST implementation matches other scripts and provides appropriate CI optimization.
86-102
: LGTM! Proper batched model and state initialization.The LennardJonesModel and SimState are correctly initialized for the batched API with proper batch dimensions.
108-110
: LGTM! Proper pressure tensor creation.The target_pressure tensor is correctly created with explicit device/dtype. Note the pressure value (10,000) differs from the Nose-Hoover script (10), which may be intentional for different simulation conditions.
121-124
: LGTM! Consistent batched function calls.The addition of
batch=state.batch
arguments and.item()
calls correctly handle the batched API requirements.Also applies to: 127-129, 141-144
torch_sim/integrators/nve.py (3)
17-57
: LGTM! Well-designed NVE integrator API.The function provides a clean interface for NVE integration with comprehensive documentation and proper typing. The approach of returning init/update function pairs is a good design pattern.
59-105
: LGTM! Robust state initialization with proper fallbacks.The initialization function correctly handles input validation, computes initial forces/energy, and provides sensible defaults for missing momenta using Maxwell-Boltzmann sampling.
107-141
: LGTM! Correct velocity Verlet implementation.The NVE update function properly implements the velocity Verlet algorithm with the correct sequence: half momentum step, position update, force recalculation, and final half momentum step. This preserves the symplectic nature and energy conservation.
torch_sim/integrators/nvt.py (1)
161-166
:calculate_momenta()
call signature: verify extrabatch
arg
calculate_momenta(state.positions, state.masses, state.batch, kT, seed)
presumes the helper takes abatch
argument.
Double-check – older helpers in this repo only accepted(positions, masses, kT, device, dtype, seed)
.torch_sim/optimizers.py (1)
586-595
: Optional fields default toNone
– but later assumed tensor
velocities=None
is stored inFireState
, yet many step functions immediately dostate.velocities += …
without a prioris None
guard (only_vv_fire_step
checks).
Ensure all step variants (_ase_fire_step
as well) initialise velocities when missing.
@@ -3,7 +3,7 @@ | |||
|
|||
# /// script | |||
# dependencies = [ | |||
# "fairchem-core", | |||
# "fairchem-core==1.10.0", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Use a version range for the fairchem-core
dependency
Pinning to 1.10.0
ensures reproducibility, but an exact version can prevent non-critical patch updates. Consider loosening to a minor range:
-# "fairchem-core==1.10.0",
+# "fairchem-core>=1.10.0,<1.11.0",
Also, verify that all example scripts use a consistent fairchem-core
version:
🏁 Script executed:
rg -n "fairchem-core" examples/
Length of output: 112
Loosen fairchem-core version pin
Using an exact version can block non-critical patch updates. Update to a minor-range spec that allows patch releases while staying within 1.x:
• File: examples/scripts/1_Introduction/1.3_Fairchem.py
(line 6)
- "fairchem-core==1.10.0",
+ "fairchem-core>=1.10.0,<1.11.0",
No other references to fairchem-core
were found in examples/
, so this change keeps all examples consistent.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/scripts/1_Introduction/1.3_Fairchem.py at line 6, the fairchem-core
dependency is pinned to an exact version 1.10.0. Change this to a minor version
range, such as "fairchem-core>=1.10.0,<1.11.0", to allow patch updates while
maintaining compatibility. This ensures flexibility for non-critical updates
without breaking reproducibility.
def get_pressure( | ||
stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 | ||
) -> torch.Tensor: | ||
"""Compute the pressure from the stress tensor. | ||
|
||
The stress tensor is defined as 1/volume * dU/de_ij | ||
So the pressure is -1/volume * trace(dU/de_ij) | ||
""" | ||
return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Improve documentation and add error handling for the pressure calculation.
The function implementation appears correct for the virial equation of state, but the docstring is incomplete and doesn't explain the full formula. Additionally, the function lacks error handling for edge cases.
Apply this diff to improve the documentation and add error handling:
def get_pressure(
stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3
) -> torch.Tensor:
- """Compute the pressure from the stress tensor.
+ """Compute the pressure using the virial equation of state.
- The stress tensor is defined as 1/volume * dU/de_ij
- So the pressure is -1/volume * trace(dU/de_ij)
+ The pressure is calculated as:
+ P = (1/dim) * (2 * kinetic_energy / volume - trace(stress))
+
+ Where:
+ - kinetic_energy: kinetic energy of the system
+ - stress: stress tensor (dU/de_ij normalized by volume)
+ - volume: system volume
+ - dim: spatial dimensions (default 3)
+
+ Args:
+ stress: Stress tensor with shape (..., dim, dim)
+ kinetic_energy: Kinetic energy with shape (...)
+ volume: System volume with shape (...)
+ dim: Spatial dimensions
+
+ Returns:
+ Pressure tensor with shape (...)
+
+ Raises:
+ ValueError: If volume is zero or negative
"""
+ if torch.any(volume <= 0):
+ raise ValueError("Volume must be positive")
return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress))
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def get_pressure( | |
stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 | |
) -> torch.Tensor: | |
"""Compute the pressure from the stress tensor. | |
The stress tensor is defined as 1/volume * dU/de_ij | |
So the pressure is -1/volume * trace(dU/de_ij) | |
""" | |
return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) | |
def get_pressure( | |
stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 | |
) -> torch.Tensor: | |
"""Compute the pressure using the virial equation of state. | |
The pressure is calculated as: | |
P = (1/dim) * (2 * kinetic_energy / volume - trace(stress)) | |
Where: | |
- kinetic_energy: kinetic energy of the system | |
- stress: stress tensor (dU/de_ij normalized by volume) | |
- volume: system volume | |
- dim: spatial dimensions (default 3) | |
Args: | |
stress: Stress tensor with shape (..., dim, dim) | |
kinetic_energy: Kinetic energy with shape (...) | |
volume: System volume with shape (...) | |
dim: Spatial dimensions | |
Returns: | |
Pressure tensor with shape (...) | |
Raises: | |
ValueError: If volume is zero or negative | |
""" | |
if torch.any(volume <= 0): | |
raise ValueError("Volume must be positive") | |
return 1 / dim * ((2 * kinetic_energy / volume) - torch.einsum("...ii", stress)) |
🤖 Prompt for AI Agents
In torch_sim/quantities.py around lines 133 to 141, improve the docstring to
fully explain the pressure calculation formula, including the virial equation of
state context. Add error handling to check that the input tensors stress,
kinetic_energy, and volume have compatible shapes and valid values (e.g., volume
should not be zero or negative) to prevent runtime errors. Raise appropriate
exceptions if inputs are invalid.
nvt_init, nvt_update = nvt_langevin(model=model, kT=kT, dt=dt) | ||
state = nvt_init(state, kT=kT, seed=1) | ||
|
||
stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) | ||
for step in range(N_steps): | ||
temp = calc_kT(masses=state.masses, momenta=state.momenta) / Units.temperature | ||
invariant = nvt_nose_hoover_invariant(state, kT=kT).item() | ||
print(f"{step=}: Temperature: {temp:.4f}: invariant: {invariant:.4f}") | ||
temp = ( | ||
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) | ||
/ Units.temperature | ||
) | ||
invariant = nvt_langevin(state, kT=kT).item() | ||
print(f"{step=}: Temperature: {temp.item():.4f}: invariant: {invariant:.4f}") | ||
state = nvt_update(state, kT=kT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Broken invariant call – wrong function invoked inside the loop
nvt_langevin()
is invoked a second time inside the time-integration loop and (as the CI log shows) raises TypeError: missing keyword-only argument: 'dt'
.
The integrator factory already returned nvt_init
and nvt_update
; most likely it also returns an nvt_invariant
(or similarly named) callback that should be used here.
-# build integrator
-nvt_init, nvt_update = nvt_langevin(model=model, kT=kT, dt=dt)
+ # build integrator – the factory returns three callables
+nvt_init, nvt_update, nvt_invariant = nvt_langevin(model=model, kT=kT, dt=dt)
...
- invariant = nvt_langevin(state, kT=kT).item()
+ invariant = nvt_invariant(state, kT=kT).item()
This resolves the missing-argument error and prevents the factory from being rebuilt at every step.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
nvt_init, nvt_update = nvt_langevin(model=model, kT=kT, dt=dt) | |
state = nvt_init(state, kT=kT, seed=1) | |
stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) | |
for step in range(N_steps): | |
temp = calc_kT(masses=state.masses, momenta=state.momenta) / Units.temperature | |
invariant = nvt_nose_hoover_invariant(state, kT=kT).item() | |
print(f"{step=}: Temperature: {temp:.4f}: invariant: {invariant:.4f}") | |
temp = ( | |
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) | |
/ Units.temperature | |
) | |
invariant = nvt_langevin(state, kT=kT).item() | |
print(f"{step=}: Temperature: {temp.item():.4f}: invariant: {invariant:.4f}") | |
state = nvt_update(state, kT=kT) | |
# build integrator – the factory returns three callables | |
nvt_init, nvt_update, nvt_invariant = nvt_langevin(model=model, kT=kT, dt=dt) | |
state = nvt_init(state, kT=kT, seed=1) | |
stress = torch.zeros(N_steps // 10, 3, 3, device=device, dtype=dtype) | |
for step in range(N_steps): | |
temp = ( | |
calc_kT(masses=state.masses, momenta=state.momenta, batch=state.batch) | |
/ Units.temperature | |
) | |
invariant = nvt_invariant(state, kT=kT).item() | |
print(f"{step=}: Temperature: {temp.item():.4f}: invariant: {invariant:.4f}") | |
state = nvt_update(state, kT=kT) |
🧰 Tools
🪛 GitHub Actions: Tests
[error] 71-71: TypeError: nvt_langevin() missing 1 required keyword-only argument: 'dt'. The function call at line 71 is missing the required 'dt' argument.
🤖 Prompt for AI Agents
In examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py around lines 62
to 73, the code incorrectly calls nvt_langevin inside the loop, causing a
TypeError due to missing the 'dt' argument and unnecessarily recreating the
integrator. Instead, use the invariant function returned by the nvt_langevin
factory (likely named nvt_invariant or similar) to compute the invariant inside
the loop. Replace the direct call to nvt_langevin(state, kT=kT) with this
invariant callback to fix the error and improve efficiency.
…g nvt_langevin in place of nvt_nose_hoover_invariant remove debug print in torch_sim/runners.py
Signed-off-by: Janosh Riebesell <[email protected]>
…m/integrators/nvt.py also fix broadcasting bug in Ornstein-Uhlenbeck (ou_step)
also fix 3.8_MACE_NPT_Nose_Hoover.py broken imports from torch_sim.integrators.npt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
♻️ Duplicate comments (1)
torch_sim/integrators/npt.py (1)
1375-1375
: Past review comment properly addressed.The
calculate_momenta
function call now uses the correct signature, resolving the previous issue.
🧹 Nitpick comments (1)
torch_sim/integrators/npt.py (1)
1339-1339
: Consider direct computation for scalar kinetic energy.While
calc_kinetic_energy
works with scalar inputs due to broadcasting, it's more natural to compute cell kinetic energy directly for clarity.- KE_cell = calc_kinetic_energy(cell_momentum, cell_mass) + KE_cell = 0.5 * cell_momentum**2 / cell_mass
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
(2 hunks)torch_sim/integrators/npt.py
(3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/integrators/npt.py (4)
torch_sim/integrators/md.py (6)
MDState
(13-48)NoseHooverChain
(195-217)NoseHooverChainFns
(221-232)calculate_momenta
(51-109)construct_nose_hoover_chain
(262-436)velocities
(44-48)torch_sim/quantities.py (1)
calc_kinetic_energy
(96-130)torch_sim/transforms.py (2)
pbc_wrap_batched
(137-192)pbc_wrap_general
(93-134)torch_sim/state.py (2)
SimState
(26-312)volume
(166-168)
🪛 GitHub Check: codecov/patch
torch_sim/integrators/npt.py
[warning] 864-864: torch_sim/integrators/npt.py#L864
Added line #L864 was not covered by tests
🪛 Pylint (3.3.7)
torch_sim/integrators/npt.py
[refactor] 797-797: Too many instance attributes (9/7)
(R0902)
[refactor] 885-885: Too many arguments (7/5)
(R0913)
[refactor] 885-885: Too many local variables (18/15)
(R0914)
[refactor] 1112-1112: Too many arguments (7/5)
(R0913)
[refactor] 1112-1112: Too many positional arguments (7/5)
(R0917)
[refactor] 1164-1164: Too many local variables (19/15)
(R0914)
[refactor] 1261-1261: Too many local variables (20/15)
(R0914)
[refactor] 885-885: Too many statements (94/50)
(R0915)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
🔇 Additional comments (3)
torch_sim/integrators/npt.py (3)
22-73
: LGTM! Well-structured NPT Langevin state class.The NPTLangevinState class properly extends SimState with the necessary fields for NPT dynamics. The momenta property correctly implements the physics relationship between velocities and momenta.
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 23-23: Too many instance attributes (8/7)
(R0902)
541-543
: LGTM! Correct usage of PBC wrapping function.The call to
pbc_wrap_batched
uses the proper arguments in the correct order matching the function signature.
796-883
: LGTM! Well-designed NPT Nose-Hoover state class.The NPTNoseHooverState class properly implements the logarithmic cell coordinate system for volume updates. The current_cell property correctly transforms from cell_position to the actual cell matrix using proper volume scaling.
🧰 Tools
🪛 GitHub Check: codecov/patch
[warning] 864-864: torch_sim/integrators/npt.py#L864
Added line #L864 was not covered by tests🪛 Pylint (3.3.7)
[refactor] 797-797: Too many instance attributes (9/7)
(R0902)
- Update NPTNoseHooverState cell properties to support batch dimensions - reference_cell: [n_batches, 3, 3] instead of [3, 3] - cell_position, cell_momentum, cell_mass: [n_batches] instead of scalar - Fix tensor broadcasting in exp_iL1, exp_iL2 with proper atom-to-batch mapping - Update compute_cell_force for per-batch kinetic energy and stress calculations - Fix npt_nose_hoover_init to properly initialize batched cell variables - Update npt_nose_hoover_invariant for per-batch energy conservation - Replace pbc_wrap_general with pbc_wrap_batched for proper PBC handling - Fix example script 3.8_MACE_NPT_Nose_Hoover.py output formatting Enables multiple independent NPT systems in a single simulation while maintaining backward compatibility for single-batch systems.
- Remove problematic cell.squeeze(0) that breaks batching support - Fix calculate_momenta function call to use correct signature with batch parameter - Resolves RuntimeError when using MACE with NVT Nose-Hoover thermostat Fixes example script 3.5_MACE_NVT_Nose_Hoover.py which was failing due to neighbor list function receiving wrong tensor shapes when cell batch dimension was incorrectly removed.
- Convert scalar kT to tensor before accessing .ndim attribute in npt_nose_hoover_init and update_cell_mass - Fix stress tensor trace computation in compute_cell_force to handle 3D batched tensors - Use torch.diagonal().sum() for batched stress tensors instead of torch.trace() Fixes Lennard-Jones NPT Nose-Hoover script that was failing with: - AttributeError: 'float' object has no attribute 'ndim' - RuntimeError: trace: expected a matrix, but got tensor with dim 3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Nitpick comments (2)
torch_sim/integrators/nvt.py (2)
268-273
: Remove redundantvelocities
override
NVTNoseHooverState
inherits the samevelocities
property fromMDState
. This override is identical and can be removed to reduce duplication.
444-468
: Align docstring parameters with signatureThe docstring for
nvt_nose_hoover_invariant
lists anenergy_fn
argument that no longer exists and omits the actual(state, kT)
parameters. Please update theArgs:
section to match the function signature.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/integrators/nvt.py
(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/integrators/nvt.py (4)
torch_sim/integrators/md.py (8)
MDState
(13-48)NoseHooverChain
(195-217)NoseHooverChainFns
(221-232)calculate_momenta
(51-109)momentum_step
(112-129)position_step
(132-156)velocity_verlet
(159-191)velocities
(44-48)torch_sim/quantities.py (2)
calc_kinetic_energy
(96-130)count_dof
(10-19)torch_sim/state.py (4)
SimState
(26-312)device
(137-139)dtype
(142-144)momenta
(319-325)tests/conftest.py (2)
device
(24-25)dtype
(29-30)
🪛 Pylint (3.3.7)
torch_sim/integrators/nvt.py
[refactor] 275-275: Too many arguments (6/5)
(R0913)
⏰ Context from checks skipped due to timeout of 90000ms (31)
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/metatomic_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (2)
torch_sim/integrators/nvt.py (2)
70-71
: Ensure model.device and model.dtype are availableAccessing
model.device
andmodel.dtype
assumes these attributes exist on the passedtorch.nn.Module
. Consider deriving the device and dtype from the input state (e.g.,state.positions.device
andstate.positions.dtype
) or explicitly documenting thatmodel
must expose these properties.
115-123
: Correct per-particle broadcasting in OU stepIndexing
kT
andc1
withstate.batch
ensures both arrays have shape[n_particles]
, soc2 = sqrt(kT*(1-c1²)).unsqueeze(-1)
and the subsequent momentum update broadcast correctly. Nice resolution of the previous broadcasting bug.
- Replace deprecated .T with .mT for matrix transpose on 3D tensors - Add support for batched cell tensors with shape [n_batches, 3, 3] - Extract first batch cell matrix when cell.ndim == 3 - Maintains backward compatibility with 2D cell matrices Fixes batched silicon workflow script that was failing with: - UserWarning about deprecated .T usage on >2D tensors - RuntimeError: linalg.solve: A must be batches of square matrices The get_fractional_coordinates function now properly handles both single [3,3] and batched [n_batches,3,3] cell tensors, enabling a2c_silicon_batched.py workflow to run
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
♻️ Duplicate comments (5)
torch_sim/integrators/nvt.py (4)
70-74
:⚠️ Potential issueAvoid truth-testing tensors for the default
gamma
gamma = gamma or 1 / (100 * dt)
still relies on Python truthiness; this blows up whengamma
is atorch.Tensor
(non-zero tensors are not truthy).
Identical feedback was given earlier but the line persists.- gamma = gamma or 1 / (100 * dt) + if gamma is None: + gamma = 1 / (100 * dt)
427-429
:⚠️ Potential issueSame batching bug during update
KE
inside the update step should mirror the fix above; otherwise the chain drifts after the first step.- KE = calc_kinetic_energy(state.momenta, state.masses) + KE = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch)
368-374
: 🛠️ Refactor suggestion
⚠️ Potential issuePer-batch DOF & kinetic energy are required for a batched Nose–Hoover chain
KE
anddof
are computed globally – this silently mixes all batches and breaks thermostat physics for multi-system runs.- KE = calc_kinetic_energy(momenta, state.masses) - dof = count_dof(state.positions) + KE = calc_kinetic_energy(momenta, state.masses, batch=state.batch) + # degrees of freedom per batch (3N – constraints) + dof = torch.bincount(state.batch)
construct_nose_hoover_chain
must accept per-batch tensors afterwards; otherwise the chain length/mass logic will mis-behave.
470-489
:⚠️ Potential issueInvariant mixes scalars and per-batch tensors – broadcasting errors ahead
e_kin
anddof
are scalars whilestate.energy
andkT
may be[n_batches]
.
As soon asn_batches>1
this sums tensors of incompatible shapes.- e_kin = calc_kinetic_energy(state.momenta, state.masses) - dof = count_dof(state.positions) + e_kin = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch) + dof = torch.bincount(state.batch)After that, loop over batches (or vectorise) so that each term has consistent shape before summation.
torch_sim/integrators/npt.py (1)
1109-1113
: Still missing thestate.pbc
guard around wrappingPrevious review already pointed this out; the unconditional call will wrap even for non-periodic systems.
- return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.batch - ) + if state.pbc: + new_positions = ts.transforms.pbc_wrap_batched( + new_positions, state.current_cell, state.batch + ) + return new_positions
🧹 Nitpick comments (1)
examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py (1)
55-58
: Nit: variable names mis-match the simulation being run
N_steps_nvt
is used for an NPT loop. Renaming toN_steps_npt_pre
/_prod
(or similar) avoids future confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
(2 hunks)examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py
(2 hunks)torch_sim/integrators/npt.py
(1 hunks)torch_sim/integrators/nvt.py
(1 hunks)torch_sim/transforms.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
torch_sim/integrators/nvt.py (4)
torch_sim/integrators/md.py (8)
MDState
(13-48)NoseHooverChain
(195-217)NoseHooverChainFns
(221-232)calculate_momenta
(51-109)momentum_step
(112-129)position_step
(132-156)velocity_verlet
(159-191)velocities
(44-48)torch_sim/quantities.py (2)
calc_kinetic_energy
(96-130)count_dof
(10-19)torch_sim/state.py (4)
SimState
(26-312)device
(137-139)dtype
(142-144)momenta
(319-325)tests/conftest.py (2)
device
(24-25)dtype
(29-30)
torch_sim/integrators/npt.py (4)
torch_sim/integrators/md.py (6)
MDState
(13-48)NoseHooverChain
(195-217)NoseHooverChainFns
(221-232)calculate_momenta
(51-109)construct_nose_hoover_chain
(262-436)velocities
(44-48)torch_sim/quantities.py (1)
calc_kinetic_energy
(96-130)torch_sim/state.py (5)
SimState
(26-312)n_batches
(161-163)n_atoms_per_batch
(152-158)clone
(198-214)volume
(166-168)torch_sim/transforms.py (1)
pbc_wrap_batched
(143-198)
🪛 Pylint (3.3.7)
torch_sim/integrators/nvt.py
[refactor] 275-275: Too many arguments (6/5)
(R0913)
torch_sim/integrators/npt.py
[refactor] 23-23: Too many instance attributes (8/7)
(R0902)
[refactor] 142-142: Too many arguments (8/5)
(R0913)
[refactor] 142-142: Too many local variables (19/15)
(R0914)
[refactor] 388-388: Too many arguments (6/5)
(R0913)
[refactor] 388-388: Too many positional arguments (6/5)
(R0917)
[refactor] 388-388: Too many local variables (17/15)
(R0914)
[refactor] 464-464: Too many local variables (19/15)
(R0914)
[refactor] 557-557: Too many local variables (17/15)
(R0914)
[refactor] 692-692: Too many arguments (6/5)
(R0913)
[refactor] 692-692: Too many positional arguments (6/5)
(R0917)
[refactor] 692-692: Too many local variables (19/15)
(R0914)
[refactor] 142-142: Too many statements (157/50)
(R0915)
[refactor] 807-807: Too many instance attributes (9/7)
(R0902)
[refactor] 899-899: Too many arguments (7/5)
(R0913)
[refactor] 899-899: Too many local variables (18/15)
(R0914)
[refactor] 1115-1115: Too many arguments (6/5)
(R0913)
[refactor] 1115-1115: Too many positional arguments (6/5)
(R0917)
[refactor] 1171-1171: Too many arguments (8/5)
(R0913)
[refactor] 1171-1171: Too many positional arguments (8/5)
(R0917)
[refactor] 1171-1171: Too many local variables (17/15)
(R0914)
[refactor] 1242-1242: Too many local variables (20/15)
(R0914)
[refactor] 1346-1346: Too many local variables (26/15)
(R0914)
[refactor] 899-899: Too many statements (129/50)
(R0915)
[refactor] 1559-1559: Too many local variables (19/15)
(R0914)
[refactor] 1559-1559: Too many branches (13/12)
(R0912)
🪛 GitHub Check: codecov/patch
torch_sim/integrators/npt.py
[warning] 100-100: torch_sim/integrators/npt.py#L100
Added line #L100 was not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms (31)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, metatomic, tests/models/test_metatomic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (1)
torch_sim/integrators/npt.py (1)
1428-1432
:calc_kinetic_energy()
likely receives invalid 1-D tensors
cell_momentum
/cell_mass
are 1-D. The helper inquantities.py
expects 2-D momenta (shape[N, dim]
).
This can raise a shape error or silently give wrong KE. Either reshape or compute explicitly:- KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1]) + KE_cell = (cell_momentum.pow(2) / (2 * cell_mass)).clone()Please check the expected signature and adjust all other calls (
update
path at line 1544) accordingly.
# Generate batch-specific noise with correct shape | ||
noise = torch.randn_like(state.velocities) | ||
|
||
# Calculate the thermal noise amplitude by batch | ||
batch_kT = kT | ||
if kT.ndim == 0: | ||
batch_kT = kT.expand(state.n_batches) | ||
|
||
# Map batch kT to atoms | ||
atom_kT = batch_kT[state.batch] | ||
|
||
# Calculate the prefactor for each atom | ||
# The standard deviation should be sqrt(2*alpha*kB*T*dt) | ||
prefactor = torch.sqrt(2 * alpha * atom_kT * dt) | ||
|
||
return prefactor.unsqueeze(-1) * noise |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
…tic energy calculations - Add missing `batch=state.batch` parameter to `calc_kinetic_energy` calls in NPT integrator - Enhance NVT/NPT invariant functions with explicit broadcasting for chain variables - Replace inefficient manual loops with batch-aware kinetic energy calculation in NPT invariant - Fix undefined variable reference in NPT invariant function (n_batches → state.n_batches) Resolves broadcasting issues where scalar kinetic energy was incorrectly combined with batched energy and temperature tensors, ensuring all energy terms have consistent batch dimensions before addition.
…oordinates Replace problematic batched cell handling that only processed the first batch (cell[0]) with explicit NotImplementedError. This prevents silent data corruption where multi-batch systems would incorrectly use only the first batch's cell parameters for all coordinate transformations. - Raise NotImplementedError for 3D batched cell tensors instead of silently ignoring batches 1, 2, ... N - Preserve existing functionality for 2D cell matrices - Add clear error message indicating limitation and suggesting workarounds Breaking change: Code that previously silently failed will now explicitly error, but this prevents incorrect results in multi-batch scenarios.
- Check state.pbc before applying ts.transforms.pbc_wrap_batched - Return unwrapped positions when state.pbc is False - Ensure consistent behavior with NPT Langevin implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
♻️ Duplicate comments (4)
torch_sim/integrators/npt.py (2)
240-255
:⚠️ Potential issueCritical: Broadcasting dimension mismatch remains unresolved.
The issue identified in previous reviews regarding incorrect broadcasting in the
beta()
function has not been properly addressed. The calculation on line 253 still uses batch-levelalpha
anddt
with atom-levelatom_kT
, which will cause dimension mismatches in multi-batch simulations.Apply the previously suggested fix:
- # Calculate the prefactor for each atom - # The standard deviation should be sqrt(2*alpha*kB*T*dt) - prefactor = torch.sqrt(2 * alpha * atom_kT * dt) + # Map batch parameters to atom level + alpha_atoms = alpha if alpha.ndim == 0 else alpha[state.batch] + dt_atoms = dt if dt.ndim == 0 else dt[state.batch] + + # Standard deviation √(2 α kT Δt) for every atom + prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms)
1111-1113
:⚠️ Potential issueCritical: Missing PBC check in position wrapping remains unresolved.
The function still always applies periodic boundary wrapping without checking
state.pbc
, which is inconsistent with other implementations and could cause incorrect behavior when PBC is disabled.Apply the previously suggested fix:
- # Apply periodic boundary conditions - return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.batch - ) + # Apply periodic boundary conditions if enabled + if state.pbc: + return ts.transforms.pbc_wrap_batched( + new_positions, state.current_cell, state.batch + ) + return new_positionstorch_sim/integrators/nvt.py (1)
72-76
: 🛠️ Refactor suggestionAvoid truth-testing Tensors for default gamma.
Using
gamma = gamma or 1 / (100 * dt)
triggers ambiguous boolean logic whengamma
is a Tensor. Switch to an explicitif gamma is None
guard.- gamma = gamma or 1 / (100 * dt) + if gamma is None: + gamma = 1 / (100 * dt)torch_sim/optimizers.py (1)
1456-1470
:cur_deform_grad
undefined when no cell optimization.This is the same issue identified in previous reviews where
cur_deform_grad
is only initialized inside theis_cell_optimization
conditional blocks but is later used unguarded, causingUnboundLocalError
during pure atomic optimization.
🧹 Nitpick comments (4)
torch_sim/integrators/npt.py (2)
142-803
: Consider refactoring for improved maintainability.The
npt_langevin
function is extremely large (157 statements) with many nested functions, making it difficult to test, debug, and maintain. Consider extracting more of the nested functions as standalone functions or organizing the code into classes.Benefits of refactoring:
- Easier unit testing of individual components
- Better code reusability across integrators
- Improved readability and debugging
- Reduced cognitive complexity
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 142-142: Too many arguments (8/5)
(R0913)
[refactor] 142-142: Too many local variables (19/15)
(R0914)
[refactor] 388-388: Too many arguments (6/5)
(R0913)
[refactor] 388-388: Too many positional arguments (6/5)
(R0917)
[refactor] 388-388: Too many local variables (17/15)
(R0914)
[refactor] 464-464: Too many local variables (19/15)
(R0914)
[refactor] 557-557: Too many local variables (17/15)
(R0914)
[refactor] 692-692: Too many arguments (6/5)
(R0913)
[refactor] 692-692: Too many positional arguments (6/5)
(R0917)
[refactor] 692-692: Too many local variables (19/15)
(R0914)
[refactor] 142-142: Too many statements (157/50)
(R0915)
899-1558
: Consider refactoring for improved maintainability.The
npt_nose_hoover
function is also very large (129 statements) with high complexity. Similar to the Langevin integrator, this would benefit from breaking down into smaller, more focused functions or a class-based approach.🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 899-899: Too many arguments (7/5)
(R0913)
[refactor] 899-899: Too many local variables (18/15)
(R0914)
[refactor] 1115-1115: Too many arguments (6/5)
(R0913)
[refactor] 1115-1115: Too many positional arguments (6/5)
(R0917)
[refactor] 1171-1171: Too many arguments (8/5)
(R0913)
[refactor] 1171-1171: Too many positional arguments (8/5)
(R0917)
[refactor] 1171-1171: Too many local variables (17/15)
(R0914)
[refactor] 1242-1242: Too many local variables (20/15)
(R0914)
[refactor] 1346-1346: Too many local variables (26/15)
(R0914)
[refactor] 899-899: Too many statements (129/50)
(R0915)
torch_sim/integrators/nvt.py (2)
275-283
: High parameter count innvt_nose_hoover
.This function’s signature uses six keyword-only parameters, exceeding typical argument limits. Consider grouping thermostat settings (e.g.,
chain_length
,chain_steps
,sy_steps
,dt
,kT
) into a dedicated config/dataclass to simplify the API.🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 275-275: Too many arguments (6/5)
(R0913)
457-461
: Docstring parameter mismatch innvt_nose_hoover_invariant
.The
Args
section lists anenergy_fn
parameter that doesn’t exist. Remove or update this entry to match the actual(state, kT)
signature.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
tests/test_state.py
(3 hunks)tests/test_voigt.py
(1 hunks)torch_sim/elastic.py
(3 hunks)torch_sim/integrators/npt.py
(1 hunks)torch_sim/integrators/nvt.py
(1 hunks)torch_sim/optimizers.py
(3 hunks)torch_sim/state.py
(5 hunks)torch_sim/transforms.py
(1 hunks)
✅ Files skipped from review due to trivial changes (2)
- tests/test_voigt.py
- tests/test_state.py
🚧 Files skipped from review as they are similar to previous changes (2)
- torch_sim/transforms.py
- torch_sim/state.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
torch_sim/optimizers.py (1)
torch_sim/state.py (4)
reference_row_vector_cell
(328-330)reference_row_vector_cell
(333-335)row_vector_cell
(185-187)row_vector_cell
(190-196)
torch_sim/elastic.py (1)
torch_sim/state.py (2)
row_vector_cell
(185-187)row_vector_cell
(190-196)
torch_sim/integrators/nvt.py (4)
torch_sim/integrators/md.py (8)
MDState
(13-48)NoseHooverChain
(195-217)NoseHooverChainFns
(221-232)calculate_momenta
(51-109)momentum_step
(112-129)position_step
(132-156)velocity_verlet
(159-191)velocities
(44-48)torch_sim/quantities.py (2)
calc_kinetic_energy
(96-130)count_dof
(10-19)torch_sim/state.py (5)
SimState
(26-312)device
(137-139)dtype
(142-144)momenta
(319-325)n_atoms_per_batch
(152-158)tests/conftest.py (2)
device
(24-25)dtype
(29-30)
🪛 Pylint (3.3.7)
torch_sim/integrators/npt.py
[refactor] 23-23: Too many instance attributes (8/7)
(R0902)
[refactor] 142-142: Too many arguments (8/5)
(R0913)
[refactor] 142-142: Too many local variables (19/15)
(R0914)
[refactor] 388-388: Too many arguments (6/5)
(R0913)
[refactor] 388-388: Too many positional arguments (6/5)
(R0917)
[refactor] 388-388: Too many local variables (17/15)
(R0914)
[refactor] 464-464: Too many local variables (19/15)
(R0914)
[refactor] 557-557: Too many local variables (17/15)
(R0914)
[refactor] 692-692: Too many arguments (6/5)
(R0913)
[refactor] 692-692: Too many positional arguments (6/5)
(R0917)
[refactor] 692-692: Too many local variables (19/15)
(R0914)
[refactor] 142-142: Too many statements (157/50)
(R0915)
[refactor] 807-807: Too many instance attributes (9/7)
(R0902)
[refactor] 899-899: Too many arguments (7/5)
(R0913)
[refactor] 899-899: Too many local variables (18/15)
(R0914)
[refactor] 1115-1115: Too many arguments (6/5)
(R0913)
[refactor] 1115-1115: Too many positional arguments (6/5)
(R0917)
[refactor] 1171-1171: Too many arguments (8/5)
(R0913)
[refactor] 1171-1171: Too many positional arguments (8/5)
(R0917)
[refactor] 1171-1171: Too many local variables (17/15)
(R0914)
[refactor] 1242-1242: Too many local variables (20/15)
(R0914)
[refactor] 1346-1346: Too many local variables (26/15)
(R0914)
[refactor] 899-899: Too many statements (129/50)
(R0915)
torch_sim/integrators/nvt.py
[refactor] 275-275: Too many arguments (6/5)
(R0913)
🪛 GitHub Check: codecov/patch
torch_sim/elastic.py
[warning] 815-815: torch_sim/elastic.py#L815
Added line #L815 was not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms (34)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (9)
torch_sim/integrators/npt.py (3)
22-73
: Well-structured dataclass with comprehensive documentation.The NPTLangevinState dataclass is properly designed with clear type hints, thorough docstrings, and a correct momenta property implementation.
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 23-23: Too many instance attributes (8/7)
(R0902)
76-140
: Correct implementation of cell force calculation for NPT dynamics.The physics and mathematics in this function are sound, properly handling tensor conversions, volume calculations, virial computations, and kinetic contributions. Good practice extracting this for separate testing.
1561-1669
: Correct implementation of NPT conserved quantity calculation.The
npt_nose_hoover_invariant
function properly computes the extended Hamiltonian with all necessary energy contributions (potential, kinetic, thermostat chains, barostat chains, PV work, and cell kinetic energy). The mathematics and broadcasting logic are correct for validation of NPT simulations.torch_sim/elastic.py (2)
674-674
: LGTM! Matrix transpose modernization improves readability.The replacement of
.transpose(-2, -1)
with.mT
is functionally equivalent and aligns with the codebase-wide modernization effort. This change improves code readability while maintaining the same matrix transposition behavior.
875-875
: LGTM! Symmetric stress tensor enforcement using modern syntax.The replacement with
.mT
correctly enforces symmetry in the stress tensor(stress + stress.mT) / 2
and aligns with the modernization pattern used throughout the codebase.torch_sim/optimizers.py (4)
404-404
: LGTM! Matrix transpose modernization in batch operation.The replacement of
.transpose(-2, -1)
with.mT
in the batch matrix multiplicationtorch.bmm(state.reference_row_vector_cell, cell_update.mT)
is functionally equivalent and improves code readability. This aligns with the codebase-wide modernization effort.
1570-1570
: LGTM! Consistent matrix transpose syntax in Frechet cell update.The use of
.mT
intorch.bmm(state.reference_row_vector_cell, F_new.mT)
is correct and maintains consistency with the modernization pattern used throughout the codebase.
1581-1581
: LGTM! Matrix transpose modernization in unit cell update.The replacement with
.mT
intorch.bmm(state.reference_row_vector_cell, F_new.mT)
is functionally equivalent to the previous.transpose(-2, -1)
and improves readability.
1585-1585
: LGTM! Modern matrix transpose syntax for position update.The use of
.mT
intorch.bmm(state.positions.unsqueeze(1), F_new[state.batch].mT)
correctly transposes the deformation gradient matrices and aligns with the codebase modernization effort.
device, dtype = model.device, model.dtype | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Invalid use of model.device
and model.dtype
.
torch.nn.Module
has no device
or dtype
attributes—this will raise an AttributeError
. Derive device
and dtype
from the state
(e.g., state.device
, state.dtype
) inside your init/update functions instead.
🤖 Prompt for AI Agents
In torch_sim/integrators/nvt.py around lines 70 to 71, the code incorrectly
accesses model.device and model.dtype, which do not exist on torch.nn.Module and
will cause an AttributeError. To fix this, replace model.device and model.dtype
with state.device and state.dtype, deriving device and dtype from the state
object instead.
) -> tuple[ | ||
Callable[[SimState | StateDict, torch.Tensor, int | None, Any], NVTNoseHooverState], | ||
Callable[[NVTNoseHooverState, torch.Tensor], NVTNoseHooverState], | ||
]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Type hints mismatch for Nose–Hoover callables.
The declared Callable
return types don’t align with the actual signatures:
- The init fn takes
(state, kT, tau, seed, **kwargs)
. - The update fn takes
(state, dt, kT)
.
Please correct these type annotations.
🤖 Prompt for AI Agents
In torch_sim/integrators/nvt.py around lines 283 to 286, the type hints for the
Nose–Hoover callables are incorrect. Update the first Callable to accept
parameters (state, kT, tau, seed, **kwargs) and the second Callable to accept
(state, dt, kT). Adjust their return types accordingly to match the actual
function signatures.
torch_sim/integrators/nvt.py
Outdated
# For now, use the total degrees of freedom as chain expects an int | ||
# This is a limitation that should be addressed in the chain implementation | ||
total_dof = count_dof(state.positions) |
This comment was marked as resolved.
This comment was marked as resolved.
Sorry, something went wrong.
) -> tuple[ | ||
Callable[[SimState | StateDict, torch.Tensor], MDState], | ||
Callable[[MDState, torch.Tensor], MDState], | ||
]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Incorrect type hints for returned callables.
The Callable
annotations for langevin_init
and langevin_update
don’t match their actual signatures (e.g., missing seed
, kT
, and gamma
parameters on the init, and missing kT
/gamma
on the update). Update the return type hints to reflect the true parameter lists.
🤖 Prompt for AI Agents
In torch_sim/integrators/nvt.py around lines 31 to 34, the return type hints for
the two callables are incorrect because they do not include all parameters such
as seed, kT, and gamma for the init function, and kT and gamma for the update
function. Update the Callable type annotations to include these missing
parameters in their signatures so that the type hints accurately reflect the
actual function parameters.
@@ -812,7 +812,7 @@ def get_strain( | |||
u = torch.matmul(reference_inverse, cell_difference) | |||
|
|||
# Compute symmetric strain tensor: ε = (u + u^T)/2 | |||
strain = (u + u.transpose(-2, -1)) / 2 | |||
strain = (u + u.mT) / 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
LGTM! Consistent symmetric strain tensor calculation.
The use of .mT
for the symmetric strain tensor calculation (u + u.mT) / 2
is correct and improves readability. However, note that this line lacks test coverage according to static analysis.
🏁 Script executed:
#!/bin/bash
# Check test coverage for the get_strain function
rg -A 10 -B 5 "def test.*strain" --type py
Length of output: 44
To ensure we’re not missing any existing tests for get_strain
or the symmetric-strain calculation, let’s broaden our search:
🏁 Script executed:
#!/bin/bash
# 1. Locate the get_strain definition
rg -n "def\s\+get_strain" torch_sim/elastic.py
# 2. Look for any direct invocation of get_strain in test files
rg -n "get_strain" -g "*test*.py"
# 3. Search for any mention of 'strain' in test files (first 50 matches)
rg -n "strain" -g "*test*.py" | head -50
Length of output: 696
Let’s locate the enclosing function around that symmetric‐strain calculation so we can verify its test coverage:
🏁 Script executed:
#!/bin/bash
# Show 5 lines before and after the symmetric-strain line to get the function signature
rg -n -C 5 "strain = (u + u.mT)" torch_sim/elastic.py
Length of output: 55
Let’s precisely locate the symmetric‐strain line and its surrounding function definition:
🏁 Script executed:
#!/bin/bash
# 1. Find any occurrences of “.mT” to locate the snippet
rg -n "\.mT" -C 10 torch_sim/elastic.py
# 2. Dump lines 800–830 to inspect the function signature around line 815
sed -n '800,830p' torch_sim/elastic.py
Length of output: 3318
Missing Test Coverage for Symmetric Strain Tensor
I didn’t find any existing tests that exercise the symmetric‐strain calculation in get_elementary_deformations
(torch_sim/elastic.py:815). Please add unit tests to validate that for known deformations (u + u.mT)/2
produces the expected symmetric strain tensor (e.g., zero strain, pure shear, hydrostatic strain).
• File: torch_sim/elastic.py
– Line 815: strain = (u + u.mT) / 2
🧰 Tools
🪛 GitHub Check: codecov/patch
[warning] 815-815: torch_sim/elastic.py#L815
Added line #L815 was not covered by tests
🤖 Prompt for AI Agents
In torch_sim/elastic.py at line 815, the symmetric strain tensor calculation
using (u + u.mT) / 2 lacks test coverage. Add unit tests for the function
get_elementary_deformations that include cases with known deformation tensors
such as zero strain, pure shear, and hydrostatic strain. Verify that the
computed strain matches the expected symmetric strain tensor in each case to
ensure this line is properly tested.
Replace count_dof() with proper batch-aware DOF calculation using torch.bincount(state.batch) to ensure consistency with invariant function. Previously ignored batch structure, causing incorrect DOF for batched systems.
…pport in transforms.py's get_fractional_coordinates - Handle batched cell tensors with shape [1, 3, 3] by auto-squeezing to [3, 3] - Improve error messages for multi-batch cases to be more informative - Add comprehensive tests for batched cell tensor scenarios - Fixes NotImplementedError in examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py:159 - Maintains full backward compatibility with existing 2D cell matrix usage
Warning Rate limit exceeded@janosh has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 3 minutes and 41 seconds before requesting another review. ⌛ How to resolve this issue?After the wait time has elapsed, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout. Please see our FAQ for further information. 📒 Files selected for processing (76)
WalkthroughThis update removes all unbatched model, integrator, and optimizer code, along with their corresponding tests, scripts, and documentation. The main codebase is refactored to use only batched models and batched integrators/optimizers. New modularized integrator implementations are introduced, and all scripts, tests, and workflow utilities are updated to use the new batched APIs and model classes. Several utility and validation improvements are also included. Changes
Sequence Diagram(s)Example: Batched Model Forward PasssequenceDiagram
participant Script
participant BatchedModel
participant PairPotential
Script->>BatchedModel: forward(state)
BatchedModel->>PairPotential: compute energies/forces (batched)
PairPotential-->>BatchedModel: returns tensors
BatchedModel-->>Script: returns result dict (batched tensors)
Example: Batched Integrator Initialization and UpdatesequenceDiagram
participant UserScript
participant Integrator
participant Model
UserScript->>Integrator: init_fn(state, kT)
Integrator->>Model: model(state)
Model-->>Integrator: energy, forces
Integrator-->>UserScript: MDState
loop MD steps
UserScript->>Integrator: update_fn(MDState, dt)
Integrator->>Model: model(new_state)
Model-->>Integrator: energy, forces
Integrator-->>UserScript: updated MDState
end
Poem
✨ Finishing Touches
🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that was a BIG lift!
Summary
Unbatched code was useful during the development of this package. Initially it was included in the released package because of it's pedagogical value. However going forward it was decided that the educational benefits do not outweigh the increased maintenance burden and so it is being removed.
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit install
to install the hooks which will check your code before each commit.Summary by CodeRabbit
New Features
Refactor
Bug Fixes
Documentation
Chores
Tests