-
Notifications
You must be signed in to change notification settings - Fork 42
Improve Typing of ModelInterface #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughThe Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Poem
Note 🔌 MCP (Model Context Protocol) integration is now available in Early Access!Pro users can now connect to remote MCP servers under the Integrations page to get reviews and chat conversations that understand additional development context. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/workflows/a2c.py
(2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/workflows/a2c.py (1)
torch_sim/optimizers.py (1)
FireState
(441-482)
🪛 GitHub Check: codecov/patch
torch_sim/workflows/a2c.py
[warning] 825-825: torch_sim/workflows/a2c.py#L825
Added line #L825 was not covered by tests
[warning] 827-827: torch_sim/workflows/a2c.py#L827
Added line #L827 was not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (37)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (1)
torch_sim/workflows/a2c.py (1)
825-827
: LGTM! Explicit typing improves type safety.The explicit
float
typing forfinal_energy
andfinal_pressure
helps the type checker understand the variable types and improves code clarity.
uvx ty check
uvx ty check
or uvx mypy torch_sim --strict-equality
uvx ty check
or uvx mypy torch_sim --strict-equality
uvx basedpyright
For visibility we are probably going to type the entire repo in parts. We'll merge this PR soon or break it up into more understandable parts |
|
||
|
||
class ModelInterface(ABC): | ||
class ModelInterface(torch.nn.Module, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice change
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🔭 Outside diff range comments (14)
torch_sim/models/particle_life.py (2)
193-193
: Return type bug: energy stored as Python float instead of torch.Tensorresults["energy"] is set to 0.0 (float), but the return type is dict[str, torch.Tensor] and forward later stacks energies across systems. This will break stacking and typing.
Apply this diff:
- results = {"energy": 0.0} + results = { + "energy": torch.tensor(0.0, dtype=self.dtype, device=self.device) + }
196-199
: Incorrect call signature: unexpected keyword 'epsilon' and missing required 'A'asymmetric_particle_pair_force_jit(dr, A, beta, sigma) does not accept epsilon. Passing epsilon=... will fail type-checks and at runtime.
Use A=self.epsilon and keep sigma keyword:
- pair_forces = asymmetric_particle_pair_force_jit( - distances, sigma=self.sigma, epsilon=self.epsilon - ) + pair_forces = asymmetric_particle_pair_force_jit( + distances, + A=self.epsilon, + sigma=self.sigma, + )Optionally pass beta if you intend a non-default inner radius:
+ beta=DEFAULT_BETA,
torch_sim/models/fairchem.py (1)
324-341
: Alignforward
return type withModelInterface
The
forward
method intorch_sim/models/fairchem.py
must use the same return annotation as the base interface to satisfy static checks.• File:
torch_sim/models/fairchem.py
,forward
definition (around lines 324–341)
• Change return type fromdict
todict[str, torch.Tensor]
- def forward(self, state: ts.SimState | StateDict) -> dict: + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:Also ensure every value in the returned dictionary is a
torch.Tensor
.torch_sim/models/orb.py (1)
333-347
: Fix model-path handling and isinstance union for portabilityTwo issues in init:
self.system_config = system_config or model.system_config
runs before path loading; ifmodel
is a str/Path, this fails.isinstance(model, str | Path)
may not be portable; use a tuple for compatibility.- # Set up system configuration - self.system_config = system_config or model.system_config - self._max_num_neighbors = max_num_neighbors - self._edge_method = edge_method - self._half_supercell = half_supercell - self.conservative = conservative - - # Load model if path is provided - if isinstance(model, str | Path): - model = torch.load(model, map_location=self._device) + # Load model if path is provided + if isinstance(model, (str, Path)): + model = torch.load(model, map_location=self._device) + + # Set up system configuration (after model is resolved) + self.system_config = system_config or model.system_config + self._max_num_neighbors = max_num_neighbors + self._edge_method = edge_method + self._half_supercell = half_supercell + self.conservative = conservativetorch_sim/models/graphpes.py (1)
102-142
: Override signature should match ModelInterface (add**kwargs
)ModelInterface.forward accepts
**kwargs
. Overriding without**kwargs
can trigger basedpyright override-compatibility errors. Accept and ignore extra kwargs in the wrapper.Apply this diff to the method signature:
- def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + def forward( + self, state: ts.SimState | StateDict, **kwargs: Any + ) -> dict[str, torch.Tensor]:Optional: instead of
# type: ignore[return-value]
, cast to the expected mapping to keep the file clean for basedpyright:- return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value] + return typing.cast( + dict[str, torch.Tensor], + self._gp_model.predict(atomic_graph, self._properties), + )torch_sim/models/metatomic.py (2)
51-62
: Override signature should match ModelInterface (add**kwargs
)ModelInterface.forward includes
**kwargs
. This override omits it, which can fail type checking with basedpyright. Accept and ignore extra kwargs.- def forward( # noqa: C901, PLR0915 - self, - state: ts.SimState | StateDict, - ) -> dict[str, torch.Tensor]: + def forward( # noqa: C901, PLR0915 + self, + state: ts.SimState | StateDict, + **kwargs: Any, + ) -> dict[str, torch.Tensor]:
99-116
:Path
handling bug:.endswith
used onPath
model
is typedstr | Path | None
but the implementation callsmodel.endswith(...)
. This will raise onPath
. Normalize viaPath(model)
when not using the literal"pet-mad"
.- if model == "pet-mad": + if model == "pet-mad": path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" self._model = load_model(path).export() - elif model.endswith(".ckpt"): - path = model - self._model = load_model(path).export() - elif model.endswith(".pt"): - path = model - self._model = load_atomistic_model(path, extensions_path) + else: + path_obj = Path(model) # type: ignore[arg-type] + if path_obj.suffix == ".ckpt": + self._model = load_model(str(path_obj)).export() + elif path_obj.suffix == ".pt": + self._model = load_atomistic_model(str(path_obj), extensions_path) + else: + raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".') - else: - raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')torch_sim/models/interface.py (1)
135-170
: Ensure allforward
overrides include**kwargs
The base
ModelInterface.forward
declares**kwargs
to support model-specific options and satisfy LSP/type checking. The following subclasses currently defineforward
without**kwargs
and should be updated:
- torch_sim/models/fairchem.py:324
- torch_sim/models/graphpes.py:177
- torch_sim/models/lennard_jones.py:356
- torch_sim/models/mace.py:240
- torch_sim/models/mattersim.py:113
- torch_sim/models/metatomic.py:151
- torch_sim/models/morse.py:346
- torch_sim/models/orb.py:361
- torch_sim/models/particle_life.py:213
- torch_sim/models/sevennet.py:152
- torch_sim/models/soft_sphere.py:380, 807
Please revise each signature, for example:
- def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: + def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]:This ensures all overrides remain signature-compatible and maintain proper extensibility.
torch_sim/models/soft_sphere.py (5)
283-285
: Fix wrong masses shape when constructing SimState from dict
torch.ones_like(state["positions"])
produces shape [n_atoms, 3], but masses must be [n_atoms]. This will break both runtime and types.Apply:
- if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + if isinstance(state, dict): + masses = torch.ones( + state["positions"].shape[0], dtype=self.dtype, device=self.device + ) + state = ts.SimState(**state, masses=masses)
351-353
: Guard against division by zero in force projectionIf any pair has zero distance (overlap),
(pair_forces / distances)
will produce inf/NaN. Clamp/guard the denominator.- force_vectors = (pair_forces / distances)[:, None] * dr_vec + denom = torch.where(distances > 0, distances, torch.ones_like(distances)) + force_vectors = (pair_forces / denom)[:, None] * dr_vec
606-610
: Breaks when species=None: cannot infer n_species
n_species = len(torch.unique(species))
fails ifspecies is None
. You either needspecies
or a way to infer n_species from provided matrices. Minimal fix: infer from any provided *_matrix; otherwise requirespecies
.- self.species = species - n_species = len(torch.unique(species)) + self.species = species + if species is None: + if sigma_matrix is not None: + n_species = sigma_matrix.shape[0] + elif epsilon_matrix is not None: + n_species = epsilon_matrix.shape[0] + elif alpha_matrix is not None: + n_species = alpha_matrix.shape[0] + else: + raise ValueError( + "Provide `species` or at least one of {sigma_matrix, epsilon_matrix, alpha_matrix} to infer n_species." + ) + else: + n_species = int(torch.unique(species).numel())
711-718
: Typo in vesin_nl_ts argument:sorti
→sort_id
This will raise a TypeError at runtime.
- sorti=False, + sort_id=False,
870-876
: Incorrect stacking/cat of results: double-processing forces/energies/stressesYou first stack 'forces', 'energies', 'stresses' and then immediately overwrite by concatenating them. Follow the same pattern as SoftSphereModel: stack only 'stress' and 'energy'; cat the rest.
- for key in ("stress", "energy", "forces", "energies", "stresses"): + for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs])(The following cat loop stays as-is.)
torch_sim/integrators/npt.py (1)
1460-1499
: Avoid constructing NPTNoseHooverState with momenta=None (type mismatch); initialize momenta and thermostat first
NPTNoseHooverState
declaresmomenta: torch.Tensor
. PassingNone
will fail strict type checking (and can break runtime assumptions). Computemomenta
earlier and pass it into the constructor. Also initialize the thermostat using the computed KE.- # Create initial state - npt_state = NPTNoseHooverState( + # Initialize particle momenta before building the state + momenta = kwargs.get( + "momenta", + calculate_momenta( + state.positions, state.masses, state.system_idx, kT, seed + ), + ) + + # Initialize thermostat using particle KE + KE = calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) + thermostat = thermostat_fns.initialize(state.positions.numel(), KE, kT) + + # Create initial state with concrete tensors (no Nones) + npt_state = NPTNoseHooverState( positions=state.positions, - momenta=None, + momenta=momenta, energy=energy, forces=forces, masses=state.masses, atomic_numbers=atomic_numbers, cell=state.cell, pbc=state.pbc, system_idx=state.system_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, cell_mass=cell_mass, barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=None, + thermostat=thermostat, barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) - - # Initialize momenta - momenta = kwargs.get( - "momenta", - calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed - ), - ) - - # Initialize thermostat - npt_state.momenta = momenta - KE = calc_kinetic_energy( - momenta=npt_state.momenta, - masses=npt_state.masses, - system_idx=npt_state.system_idx, - ) - npt_state.thermostat = thermostat_fns.initialize( - npt_state.positions.numel(), KE, kT - )
♻️ Duplicate comments (1)
torch_sim/models/interface.py (1)
38-38
: Base class now inherits nn.Module — agreedConsolidating nn.Module at the interface simplifies model inheritance and typing.
🧹 Nitpick comments (4)
torch_sim/models/particle_life.py (2)
51-69
: Docstring says “Energies” but function returns forcesasymmetric_particle_pair_force_jit computes force magnitudes, not energies. Update the Returns description to prevent confusion and improve type-check clarity.
184-185
: I’d like to inspect the full mapping block in torch_sim/models/particle_life.py to verify why it’s reversed. Could you run:rg -n "mapping = torch.stack" -A 10 torch_sim/models/particle_life.py
torch_sim/models/mattersim.py (1)
24-34
: LGTM: Placeholder class now correctly derives from ModelInterfaceThis keeps typing consistent when imports fail and still raises eagerly in init. One nit: the docstring says “sevenn is not installed” — should be “mattersim”.
- It raises an ImportError if sevenn is not installed. + It raises an ImportError if MatterSim is not installed.torch_sim/models/metatomic.py (1)
79-97
: Docstring vs implementation mismatch formodel=None
behaviorDocstring says “If None, defaults to 'pet-mad'” but the code raises ValueError when
model is None
. Please align behavior with docs, or update the docstring and type to removeNone
.Two options:
- Update implementation to default to "pet-mad":
- if model is None: - raise ValueError( - "A model path, or the name of a pre-defined model, must be provided. " - 'Currently only "pet-mad" is available as a pre-defined model.' - ) + if model is None: + model = "pet-mad"
- Or update the docstring and annotation (remove
| None
) to reflect that a value is required.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (45)
.github/PULL_REQUEST_TEMPLATE.md
(1 hunks)examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py
(1 hunks)examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py
(1 hunks)examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py
(2 hunks)examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py
(1 hunks)examples/scripts/3_Dynamics/3.2_MACE_NVE.py
(1 hunks)examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
(1 hunks)examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py
(1 hunks)examples/scripts/4_High_level_api/4.1_high_level_api.py
(1 hunks)examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
(3 hunks)examples/tutorials/high_level_tutorial.py
(1 hunks)examples/tutorials/reporting_tutorial.py
(2 hunks)tests/models/test_mattersim.py
(3 hunks)tests/models/test_sevennet.py
(3 hunks)tests/test_integrators.py
(4 hunks)tests/test_monte_carlo.py
(2 hunks)tests/test_optimizers.py
(15 hunks)tests/test_runners.py
(5 hunks)tests/test_state.py
(1 hunks)tests/test_trajectory.py
(2 hunks)tests/workflows/test_a2c.py
(2 hunks)torch_sim/elastic.py
(2 hunks)torch_sim/integrators/md.py
(2 hunks)torch_sim/integrators/npt.py
(8 hunks)torch_sim/integrators/nve.py
(1 hunks)torch_sim/integrators/nvt.py
(5 hunks)torch_sim/models/fairchem.py
(2 hunks)torch_sim/models/graphpes.py
(2 hunks)torch_sim/models/interface.py
(2 hunks)torch_sim/models/lennard_jones.py
(1 hunks)torch_sim/models/mace.py
(2 hunks)torch_sim/models/mattersim.py
(2 hunks)torch_sim/models/metatomic.py
(2 hunks)torch_sim/models/morse.py
(1 hunks)torch_sim/models/orb.py
(2 hunks)torch_sim/models/particle_life.py
(1 hunks)torch_sim/models/sevennet.py
(2 hunks)torch_sim/models/soft_sphere.py
(3 hunks)torch_sim/monte_carlo.py
(2 hunks)torch_sim/optimizers.py
(10 hunks)torch_sim/quantities.py
(6 hunks)torch_sim/runners.py
(1 hunks)torch_sim/state.py
(10 hunks)torch_sim/trajectory.py
(4 hunks)torch_sim/workflows/a2c.py
(8 hunks)
✅ Files skipped from review due to trivial changes (16)
- examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
- examples/scripts/4_High_level_api/4.1_high_level_api.py
- examples/scripts/3_Dynamics/3.2_MACE_NVE.py
- .github/PULL_REQUEST_TEMPLATE.md
- examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py
- examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py
- torch_sim/monte_carlo.py
- torch_sim/integrators/md.py
- torch_sim/models/mace.py
- torch_sim/integrators/nve.py
- torch_sim/trajectory.py
- torch_sim/elastic.py
- tests/test_integrators.py
- examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
- torch_sim/integrators/nvt.py
- torch_sim/optimizers.py
🧰 Additional context used
🧬 Code Graph Analysis (17)
tests/test_monte_carlo.py (2)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)tests/conftest.py (1)
lj_model
(34-45)
examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (1)
torch_sim/quantities.py (1)
calc_kinetic_energy
(101-134)
torch_sim/models/particle_life.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
tests/test_state.py (3)
torch_sim/integrators/nvt.py (1)
velocities
(270-274)torch_sim/integrators/npt.py (1)
velocities
(871-877)torch_sim/integrators/md.py (1)
velocities
(45-49)
tests/workflows/test_a2c.py (1)
torch_sim/optimizers.py (2)
FireState
(442-483)UnitCellFireState
(621-692)
torch_sim/models/sevennet.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/fairchem.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/lennard_jones.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/mattersim.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/morse.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/graphpes.py (2)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)tests/models/test_graphpes.py (2)
ts_nequip_model
(128-134)ts_mace_model
(155-161)
tests/test_trajectory.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/metatomic.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
torch_sim/models/orb.py (1)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)
examples/tutorials/reporting_tutorial.py (2)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)torch_sim/state.py (1)
SimState
(26-402)
torch_sim/quantities.py (4)
torch_sim/integrators/npt.py (2)
momenta
(71-73)velocities
(871-877)torch_sim/integrators/nvt.py (1)
velocities
(270-274)torch_sim/integrators/md.py (1)
velocities
(45-49)torch_sim/units.py (1)
MetalUnits
(83-103)
torch_sim/workflows/a2c.py (2)
torch_sim/models/interface.py (1)
ModelInterface
(38-170)torch_sim/optimizers.py (2)
FireState
(442-483)UnitCellFireState
(621-692)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, metatomic, tests/models/test_metatomic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: build-docs
🔇 Additional comments (58)
examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py (1)
81-81
: LGTM: purely a formatting consolidation of SimState kwargsNo functional changes; kwargs are correct and consistent with surrounding examples.
tests/workflows/test_a2c.py (3)
1-2
: LGTM! Appropriate import addition.The
cast
import from typing is correctly added to support the type casting needed later in the test.
9-9
: LGTM! Correct import addition.Adding
FireState
import is necessary for the type cast used later in the test function.
160-160
: LGTM! Proper handling of ambiguous return type.The cast to
FireState
correctly handles the fact thatrandom_packed_structure
can return either aFireState
or a tuple. This ensures type safety while maintaining test functionality.torch_sim/quantities.py (5)
3-4
: LGTM! Necessary import addition.The
cast
import is correctly added to support type casting in the function implementations.
25-31
: LGTM! Excellent improvement to argument handling.Making arguments keyword-only with the leading
*
and makingmomenta
optional improves the API design. This enforces clearer calling conventions and prevents positional argument mistakes.
44-50
: LGTM! Improved validation logic.The XOR validation using
((momenta is not None) ^ (velocities is not None))
is more concise and clearer than separate checks. The explicit cast totorch.Tensor
when velocities is guaranteed to be non-None ensures type safety.
73-98
: LGTM! Consistent parameter improvements.The same improvements applied to
calc_temperature
: keyword-only arguments, optional momenta parameter, and explicit keyword arguments in thecalc_kT
call. TheMetalUnits
type annotation is more precise than the previous generic object type.
101-122
: LGTM! Consistent implementation pattern.The
calc_kinetic_energy
function follows the same improved pattern with keyword-only arguments, optional momenta parameter, and consolidated XOR validation logic.torch_sim/state.py (7)
11-11
: LGTM! Appropriate imports for enhanced type safety.The additional imports support the generic type improvements and explicit casting used throughout the file.
112-147
: LGTM! Fixed constructor validation logic.The system_idx validation and assignment logic is now more robust. The validation properly handles the case where system_idx is provided, and the assignment is correctly performed after validation.
254-259
: LGTM! Proper error handling for volume property.Raising a
ValueError
for non-PBC systems is the correct approach since volume is undefined without periodic boundaries. This is better than returningNone
which could lead to silent errors downstream.
368-368
: LGTM! Explicit type casting for method return.The cast ensures the return type matches the method signature while maintaining type safety.
408-420
: LGTM! Simplified mixin interface.Removing the momenta property and adding the required attributes simplifies the mixin interface. The reference_row_vector_cell property provides a clean way to access the reference cell in row vector format.
483-490
: LGTM! Enhanced type safety with generics.The SimStateT generic type variable and its usage in function signatures preserve subclass types through operations like device conversion, which is excellent for type safety.
972-972
: LGTM! Appropriate type casting for validation.The cast ensures type safety when accessing the n_systems property on list elements during validation.
examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py (1)
100-100
: LGTM! Consistent argument ordering.Moving
pbc=True
afteratomic_numbers
maintains consistency with other example scripts while preserving functionality.torch_sim/models/morse.py (1)
115-115
: LGTM! Consistent interface inheritance.Removing
torch.nn.Module
inheritance aligns with the architectural change whereModelInterface
now provides the PyTorch module functionality. This centralizes the inheritance pattern and reduces duplication across model implementations.torch_sim/models/particle_life.py (2)
86-86
: Inheritance change aligns with new interface hierarchyGood move inheriting only from ModelInterface now that it subclasses nn.Module. super().init() is present.
155-161
: Ignore CUDA scalar conversion suggestionThe
vesin_nl_ts
signature clearly declaresdef vesin_nl_ts( …, cutoff: torch.Tensor, … )so it expects a tensor (not a Python float). Converting
self.cutoff
to a Python float—even with.item()
—would break the intended API. You can leave the call as-is.Likely an incorrect or invalid review comment.
torch_sim/models/lennard_jones.py (2)
122-122
: Inheritance change matches new designSwitching to inherit only from ModelInterface (which is an nn.Module) is consistent with the refactor.
265-271
: Ignore float conversion suggestion for cutoff
Thevesin_nl_ts
function is scripted to accept a 0-dimtorch.Tensor
and internally invokescutoff.item()
, so passingself.cutoff
(even on CUDA) is correct. Converting it to a Python float would mismatch the declaredtorch.Tensor
parameter and break the JIT call.Likely an incorrect or invalid review comment.
examples/tutorials/high_level_tutorial.py (1)
135-136
: Keyword-only call is correctSwitch to keyword args for calc_kinetic_energy matches the updated signature and improves clarity.
tests/test_monte_carlo.py (2)
6-6
: Type import aligns test with new model interfaceImporting ModelInterface keeps tests consistent with the refactor.
116-117
: Fixture type updated to ModelInterface is appropriateTyping lj_model as ModelInterface matches how models are consumed across the codebase.
torch_sim/models/fairchem.py (2)
48-56
: Placeholder class now inherits ModelInterfaceGood change; keeps type checkers happy even when FairChem isn’t installed. Raising in init is fine here.
73-73
: Main class inheritance updated appropriatelyInheriting only from ModelInterface centralizes nn.Module via the interface. super().init() is correctly called in init.
tests/test_trajectory.py (2)
12-12
: ModelInterface import aligns the test with the new model abstractionGood change. This keeps the test consistent with ModelInterface-based typing across the codebase.
752-756
: Type-narrowing to ModelInterface for energy_calculator — LGTMAnnotating the model parameter as ModelInterface is correct here since the calculator calls model(state) and accesses "energy".
examples/tutorials/reporting_tutorial.py (2)
209-209
: Importing ModelInterface for tutorial examples — good consistencyThis ensures the tutorial code type-checks under basedpyright with the new interface.
218-221
: Update calculate_energy to accept ModelInterface — correct and minimalAccurately reflects the contract that models implement ModelInterface and return an output dict with "energy".
examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (2)
101-102
: Keyword argument order tweak for SimState — OKReordering pbc after atomic_numbers using keyword args is harmless and reads cleanly.
151-156
: Keyword-only kinetic energy call and variable rename are correct
- calc_kinetic_energy now requires keyword args; usage is correct.
- Passing kinetic_energy to get_pressure matches the expected signature (stress, kinetic_energy, volume).
tests/test_state.py (1)
503-505
: Require velocities and reference_cell in DeformState — good alignmentMaking these parameters required reflects the updated mixin expectations and clarifies test intent.
torch_sim/models/mattersim.py (1)
42-55
: LGTM: Inherit only from ModelInterfaceGiven
ModelInterface
already subclassestorch.nn.Module
, this simplifies the hierarchy without functional changes.super().__init__()
still initializesnn.Module
state.torch_sim/models/orb.py (2)
42-52
: LGTM: Placeholder class now uses ModelInterfaceConsistent with the new interface-first design and keeps the eager ImportError behavior.
250-274
: LGTM: Inherit only from ModelInterfaceMatches the project-wide refactor and maintains proper
nn.Module
initialization viasuper().__init__()
.torch_sim/models/sevennet.py (2)
35-45
: LGTM: Placeholder class inherits ModelInterfaceKeeps error surfacing while maintaining consistent typing.
47-59
: LGTM: Inherit only from ModelInterfaceConsistent with project-wide refactor;
nn.Module
is still initialized viasuper().__init__()
.torch_sim/models/graphpes.py (1)
39-49
: ImportError fallback: inheritance change is correctUsing ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent now that ModelInterface subclasses nn.Module.
torch_sim/models/metatomic.py (1)
39-49
: ImportError fallback: inheritance change is correctUsing ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent.
tests/test_runners.py (1)
26-30
: Keyword-arg calls tocalc_kinetic_energy
are correctSwitching to
momenta=..., masses=...
improves clarity and typing; matches the current API and helps basedpyright.Also applies to: 63-67, 119-123, 170-174
torch_sim/models/interface.py (1)
72-75
: Declaring protected attributes at class level is the right moveThese declarations unblock basedpyright by guaranteeing attribute presence on subclasses.
torch_sim/models/soft_sphere.py (3)
438-438
: Multi-model now inherits ModelInterface — good alignment with the refactorNo issues with the base class swap.
597-602
: Private attribute renaming is fine; confirm public properties back theseSetting
_device/_dtype/_compute_*
is consistent with SoftSphereModel. As above, ensure ModelInterface exposes the publicdevice/dtype/compute_*
properties used elsewhere.Same script as in Line 133 comment can be used to confirm the properties exist.
133-133
: ModelInterface defines all required properties
Verified thatModelInterface
declares the following@property
methods intorch_sim/models/interface.py
:
device
(l. 77)dtype
(l. 89)compute_stress
(l. 101)compute_forces
(l. 113)No further changes needed.
torch_sim/integrators/npt.py (2)
17-17
: Importing ModelInterface for type annotations — goodThis aligns the integrators with the new interface abstraction.
1436-1436
: Consistent keyword-only calls to calc_kinetic_energy — LGTMAll updated calls now use
masses=
/momenta=
and passsystem_idx
where per-system results are needed. This matches the API change and improves clarity.Also applies to: 1491-1494, 1550-1552, 1555-1555, 1607-1608
torch_sim/workflows/a2c.py (8)
22-22
: LGTM! Import added for ModelInterface.The import is correctly added to support the type annotation changes throughout the file.
232-232
: LGTM! Return type annotation correctly reflects conditional behavior.The union type
FireState | tuple[FireState, list[np.ndarray]]
accurately represents that the function returns either aFireState
or a tuple based on whetherlog
is provided.
580-597
: LGTM! Performance optimization for composition handling.The refactoring to convert
restrict_to_compositions
to a set upfront and use.add()
instead of list operations is more efficient. The explicit type annotations also improve type safety.
614-614
: LGTM! Explicit type annotation improves type safety.The explicit typing of
candidates
aslist[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
makes the return type clear and helps with type checking.
709-709
: LGTM! Parameter type updated to ModelInterface.The change from
torch.nn.Module
toModelInterface
is consistent with the codebase refactoring and provides more specific typing for models in the simulation framework.
711-711
: LGTM! Return type expanded to include energy and pressure lists.The return type correctly reflects that the function now returns final energies and pressures as lists of floats in addition to the state and logger. This is consistent with the function implementation that prints and returns these values.
755-756
: LGTM! Type annotations updated for consistency.The parameter and return type annotations are properly updated to use the new type system with explicit
UnitCellFireState
anddict[str, torch.Tensor]
types.
330-331
: Missing return statement in else branch.The function has a conditional return at line 329-330 but is missing a return statement for the case when
log is None
. This will cause a runtime error.Apply this diff to fix the missing return:
if log is not None: return state, log - + return stateLikely an incorrect or invalid review comment.
tests/test_optimizers.py (2)
9-9
: LGTM! Import added for ModelInterface.The import is correctly added to support the updated type annotations throughout the test file.
27-27
: LGTM! Model parameter types updated to ModelInterface.All test functions now correctly use
ModelInterface
instead oftorch.nn.Module
for thelj_model
parameter. This is consistent with the codebase refactoring where models inherit fromModelInterface
rather than directly fromtorch.nn.Module
.Also applies to: 66-66, 115-115, 189-189, 205-205, 213-213, 276-276, 329-329, 418-418, 452-452, 471-471, 507-507, 596-596, 711-711, 787-787
1: { | ||
"ke": lambda state: calc_kinetic_energy( | ||
velocities=state.velocities, masses=state.masses | ||
) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Verify velocities
usage in FIRE reporter
Using velocities=state.velocities
is fine if SimState guarantees velocities
during FIRE optimization and calc_kinetic_energy
supports it. Please sanity-check both assumptions.
Run this script to verify the function signature and SimState fields:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Check calc_kinetic_energy signature:"
rg -n "def calc_kinetic_energy\\(" -A 5 torch_sim | sed -n '1,12p'
echo
echo "Confirm it accepts 'velocities' and/or 'momenta':"
rg -n "calc_kinetic_energy\\(" -A 2 torch_sim | rg -n "velocities|momenta|masses" -n
echo
echo "Check that SimState defines 'velocities' in contexts used by FIRE:"
rg -n "class SimState" -n torch_sim
rg -n "velocities" torch_sim | head -n 20
Length of output: 4095
Ensure SimState provides velocities for the FIRE reporter
While calc_kinetic_energy
does accept a velocities
tensor (see torch_sim/quantities.py:def calc_kinetic_energy signature), SimState does not currently define or populate a velocities
attribute:
• torch_sim/quantities.py (lines 101–107): supports velocities: Tensor | None
• torch_sim/state.py (line 966): TODO to pass velocities from external engines, no field defined
• torch_sim/runners.py (lines 49, 52): reporter calls velocities=state.velocities
Please address this by either:
- Adding/populating a
velocities
field on SimState during FIRE optimizations, or - Switching the reporter to use
momenta=state.momenta
(which SimState does define)
🤖 Prompt for AI Agents
In tests/test_runners.py around lines 365 to 369, the reporter calls
calc_kinetic_energy with state.velocities, but SimState does not define or
populate a velocities attribute. To fix this, either add and populate a
velocities field on SimState during FIRE optimizations so it can be passed
correctly, or modify the reporter to use momenta=state.momenta instead, since
momenta is already defined on SimState.
"kinetic_energy": lambda state: calc_kinetic_energy( | ||
velocities=state.velocities, masses=state.masses | ||
), | ||
"temperature": lambda state: calc_kT( | ||
velocities=state.velocities, masses=state.masses | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Kinetic/temperature calculators now depend on velocities — verify availability or guard
Switch to keyword args looks good and aligns with API. However, if any integrator/state doesn’t populate state.velocities
, these lambdas will raise at runtime. Either ensure all integrate
paths guarantee velocities, or guard the calculators (e.g., skip or return NaN when velocities are absent).
You can quickly sanity-check for residual usage of momenta
vs. assured velocities
support with:
🏁 Script executed:
#!/bin/bash
# Inspect SimState and integrators for velocities/momenta usage
rg -n "velocities" torch_sim | head -n 50
rg -n "momenta" torch_sim | head -n 50
# Check if SimState defines 'velocities' as required attr
rg -n "class SimState" -A 40 torch_sim/state.py
Length of output: 11438
🏁 Script executed:
#!/bin/bash
# Inspect runner reporters setup for state.velocities usage
rg -n "reporters" -A30 torch_sim/runners.py
# Inspect how Trajectory uses save_velocities and when velocities are attached
rg -n "class Trajectory" -A200 torch_sim/trajectory.py
# Check save_velocities flag defaults and how velocities propagate into state
rg -n "save_velocities" -A20 torch_sim/trajectory.py
Length of output: 15077
🏁 Script executed:
#!/bin/bash
# Check for StaticState definition in the codebase
rg -n "class StaticState" -A20 torch_sim/state.py
Length of output: 51
🏁 Script executed:
#!/bin/bash
# Locate StaticState definition across the repo
rg -n "class StaticState" -A20 .
# Locate MDState definition across the repo
rg -n "class MDState" -A20 .
# Show the top of runners.py (including default prop_calculators)
sed -n '1,100p' torch_sim/runners.py
Length of output: 6753
Guard kinetic/temperature reporters against missing velocities
In _configure_reporter
(torch_sim/runners.py), the default kinetic_energy
and temperature
lambdas assume state.velocities
always exists. While MDState provides a velocities
property, base SimState and static‐optimization states do not, leading to an AttributeError if users request these props outside an MD run.
• Location:
- torch_sim/runners.py, around lines 46–53 in
_configure_reporter
• Suggested change: wrap each calculator in a guard that checks forstate.velocities
and returns NaN (or skips) when unavailable
Example diff:
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ def _configure_reporter(...):
- possible_properties = {
- "kinetic_energy": lambda state: calc_kinetic_energy(
- velocities=state.velocities, masses=state.masses
- ),
- "temperature": lambda state: calc_kT(
- velocities=state.velocities, masses=state.masses
- ),
+ possible_properties = {
+ "kinetic_energy": lambda state: (
+ calc_kinetic_energy(velocities=state.velocities, masses=state.masses)
+ if getattr(state, "velocities", None) is not None
+ else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+ ),
+ "temperature": lambda state: (
+ calc_kT(velocities=state.velocities, masses=state.masses)
+ if getattr(state, "velocities", None) is not None
+ else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+ ),
This ensures reporters won’t crash when running on states without velocities.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
"kinetic_energy": lambda state: calc_kinetic_energy( | |
velocities=state.velocities, masses=state.masses | |
), | |
"temperature": lambda state: calc_kT( | |
velocities=state.velocities, masses=state.masses | |
), | |
possible_properties = { | |
"kinetic_energy": lambda state: ( | |
calc_kinetic_energy(velocities=state.velocities, masses=state.masses) | |
if getattr(state, "velocities", None) is not None | |
else torch.tensor(float("nan"), device=state.device, dtype=state.dtype) | |
), | |
"temperature": lambda state: ( | |
calc_kT(velocities=state.velocities, masses=state.masses) | |
if getattr(state, "velocities", None) is not None | |
else torch.tensor(float("nan"), device=state.device, dtype=state.dtype) | |
), | |
# … other properties … | |
} |
🤖 Prompt for AI Agents
In torch_sim/runners.py around lines 48 to 53, the lambdas for "kinetic_energy"
and "temperature" assume state.velocities always exists, which can cause
AttributeError for states without velocities. Modify these lambdas to first
check if state.velocities is present; if not, return NaN to prevent crashes.
This guard ensures the reporters handle states lacking velocities gracefully.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is a larger issue. we'll prob fix this in another PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@curtischong, that makes perfect sense! Keeping the scope focused on type checking compliance for this PR is a good approach. The runtime safety issue with missing velocities can definitely be addressed separately.
Would you like me to create a GitHub issue to track this for the future PR? It would help ensure the velocities availability issue doesn't get forgotten.
✏️ Learnings added
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
[skip ci] Signed-off-by: Rhys Goodall <[email protected]>
torch.Tensor: Scalar temperature value | ||
""" | ||
if momenta is not None and velocities is not None: | ||
if not ((momenta is not None) ^ (velocities is not None)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the value error is slightly misleading. if both of them are none then the xor will evaluate to false. but the not will negate it and the entire statement will be true.
so if they pass in no velocity and no mementia, the value error will raise. but the "not both" is misleading
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a minor nit though. Oh I get it. you merged the if statements. How about: "Must pass in one of either momenta or velocities"
If system_idx is provided: Tensor of kinetic energies per system | ||
""" | ||
if momenta is not None and velocities is not None: | ||
if not ((momenta is not None) ^ (velocities is not None)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe update this value error as well?
uvx basedpyright
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/test_quantities.py (2)
3-3
: Use public API for tensor type annotation.Importing from
torch._tensor
accesses a private module. Usetorch.Tensor
directly for better compatibility with type checkers.-from torch._tensor import Tensor +Then update the type annotations to use
torch.Tensor
instead ofTensor
:-def single_system_data() -> dict[str, Tensor]: +def single_system_data() -> dict[str, torch.Tensor]:-def batched_system_data() -> dict[str, Tensor]: +def batched_system_data() -> dict[str, torch.Tensor]:-def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: +def test_calc_kinetic_energy_single_system(single_system_data: dict[str, torch.Tensor]) -> None:And similarly for other function signatures.
127-137
: Consider adding batched system test coverage.The temperature calculation test correctly verifies the unit conversion relationship with kT. Consider adding a test case for batched systems to match the coverage of other test functions.
def test_calc_temperature_batched_system(batched_system_data: dict[str, torch.Tensor]) -> None: temp = quantities.calc_temperature( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], ) kt = quantities.calc_kT( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], ) assert torch.allclose(temp, kt / MetalUnits.temperature)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/test_quantities.py
(1 hunks)torch_sim/quantities.py
(6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- torch_sim/quantities.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
🧬 Code Graph Analysis (1)
tests/test_quantities.py (2)
torch_sim/units.py (1)
MetalUnits
(83-103)torch_sim/quantities.py (3)
calc_kinetic_energy
(101-134)calc_kT
(25-70)calc_temperature
(73-97)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (58)
- GitHub Check: test-examples (examples/scripts/7_Others/7.3_Batched_neighbor_list.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.3_Fairchem.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.11, highest)
- GitHub Check: build-docs
🔇 Additional comments (4)
tests/test_quantities.py (4)
13-26
: LGTM!The fixture provides accurate test data. The expected kinetic energy (13.5) and kT (4.5) values are mathematically correct for the given masses and velocities.
29-44
: LGTM!The batched system fixture provides accurate test data. The expected kinetic energies ([3.0, 24.0]) and kT values ([1.0, 8.0]) per system are mathematically correct.
47-92
: Excellent test coverage for kinetic energy calculations.The tests comprehensively cover:
- Single and batched system scenarios
- Both velocity and momentum input modes
- Proper error handling for invalid argument combinations
- Correct usage of keyword-only arguments
94-125
: LGTM!The kT calculation tests provide comprehensive coverage matching the kinetic energy test structure, correctly testing both single and batched systems with velocity and momentum inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall I approve - 2 minor nits you might want to change before we merge
Summary
By ensuring that basedpyright runs, we reduce the chance of bugs in our code
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit install
to install the hooks which will check your code before each commit.Summary by CodeRabbit
Summary by CodeRabbit