Skip to content

Conversation

curtischong
Copy link
Collaborator

@curtischong curtischong commented Jul 7, 2025

Summary

By ensuring that basedpyright runs, we reduce the chance of bugs in our code

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
    Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run pip install pre-commit && pre-commit install to install the hooks which will check your code before each commit.

Summary by CodeRabbit

Summary by CodeRabbit

  • New Features
    • The relaxed structure workflow now provides final energy and pressure values alongside existing outputs.
  • Improvements
    • Updated argument handling for kinetic energy and temperature calculations to use explicit keyword arguments and mutually exclusive inputs.
    • Enhanced type annotations across models, optimizers, integrators, workflows, and tests for better interface consistency and clarity.
    • Streamlined model class hierarchies by standardizing inheritance solely from a common interface.
    • Refined type hints and generics in state management for improved type safety and validation.
    • Updated internal reporting and trajectory properties to use velocities instead of momenta for kinetic energy and temperature calculations.
  • Bug Fixes
    • Corrected variable naming and argument order in example scripts and tests for clarity and consistency.
  • Refactor
    • Removed deprecated relaxation function from workflows.
    • Improved internal logging output formatting for energies and pressures during relaxation steps.
    • Adjusted attribute naming conventions in models for internal consistency.

@cla-bot cla-bot bot added the cla-signed Contributor license agreement signed label Jul 7, 2025
@curtischong curtischong marked this pull request as draft July 7, 2025 00:01
Copy link

coderabbitai bot commented Jul 7, 2025

Walkthrough

The torch_sim/workflows/a2c.py module was updated to improve type annotations by introducing ModelInterface and expanding return types for key functions. The get_relaxed_structure function was removed. Logger dictionaries and internal function signatures were refined for clarity. Multiple model classes across the codebase were changed to inherit only from ModelInterface instead of torch.nn.Module. Several integrator, optimizer, and test functions updated type hints accordingly. Argument passing in kinetic energy calculations was standardized to use keyword arguments. Minor formatting and variable renaming changes were applied in example scripts. The ModelInterface class was modified to inherit from torch.nn.Module and dropped its abstract constructor.

Changes

Cohort / File(s) Change Summary
A2C Workflow
torch_sim/workflows/a2c.py
Added import of ModelInterface; expanded return types for random_packed_structure and get_unit_cell_relaxed_structure; removed get_relaxed_structure; refined logger typing; updated internal function signatures and prints.
Model Interface and Model Classes
torch_sim/models/interface.py, torch_sim/models/*.py (fairchem, graphpes, mace, mattersim, metatomic, lennard_jones, morse, orb, particle_life, sevennet, soft_sphere)
Changed ModelInterface to inherit from torch.nn.Module; removed abstract constructor; all model classes dropped direct torch.nn.Module inheritance, inheriting only from ModelInterface. Renamed some instance attributes in SoftSphereModel classes.
Integrators and Optimizers
torch_sim/integrators/*.py, torch_sim/optimizers.py
Updated model parameter type hints from torch.nn.Module to ModelInterface; standardized calls to calc_kinetic_energy and related functions to use explicit keyword arguments. Changed return types of some optimizer init functions to return callables.
Monte Carlo and Elastic
torch_sim/monte_carlo.py, torch_sim/elastic.py
Updated model parameter type hints from torch.nn.Module to ModelInterface.
State and Trajectory
torch_sim/state.py, torch_sim/trajectory.py
Added generics and improved type hints; corrected SimState constructor logic; refined type hints in TrajectoryReporter; removed momenta property from DeformGradMixin.
Runners and Quantities
torch_sim/runners.py, torch_sim/quantities.py
Changed kinetic energy and temperature calculations to use velocities with keyword arguments instead of positional momenta; made calc_kT, calc_temperature, and calc_kinetic_energy use keyword-only arguments and improved input validation.
Tests
tests/models/*.py, tests/test_*.py, tests/workflows/test_a2c.py
Updated type annotations for model parameters from torch.nn.Module to ModelInterface; added explicit casts; updated kinetic energy lambda argument passing to keyword arguments; removed one test function; minor variable renaming.
Examples and Tutorials
examples/scripts/*, examples/tutorials/*
Reordered keyword arguments in ts.state.SimState constructor calls; changed kinetic energy lambda functions to use keyword arguments; updated type annotations for model parameters to ModelInterface.
GitHub Template
.github/PULL_REQUEST_TEMPLATE.md
Added checklist item to run uvx ty check; changed linter checklist to checkbox format.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

enhancement

Poem

🐇 In code’s vast fields where models grow,
Interfaces changed their flow.
Modules shed their old attire,
Embracing interfaces to inspire.
Arguments now clearly named,
Kinetic energies rightly framed.
A hop, a skip, a joyful cheer—
Cleaner code is finally here! 🎉

Note

🔌 MCP (Model Context Protocol) integration is now available in Early Access!

Pro users can now connect to remote MCP servers under the Integrations page to get reviews and chat conversations that understand additional development context.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch ty

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2bb3bd5 and 58601d1.

📒 Files selected for processing (1)
  • torch_sim/workflows/a2c.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/workflows/a2c.py (1)
torch_sim/optimizers.py (1)
  • FireState (441-482)
🪛 GitHub Check: codecov/patch
torch_sim/workflows/a2c.py

[warning] 825-825: torch_sim/workflows/a2c.py#L825
Added line #L825 was not covered by tests


[warning] 827-827: torch_sim/workflows/a2c.py#L827
Added line #L827 was not covered by tests

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (37)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
  • GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
  • GitHub Check: test-core (macos-14, 3.12, lowest-direct)
  • GitHub Check: build-docs
🔇 Additional comments (1)
torch_sim/workflows/a2c.py (1)

825-827: LGTM! Explicit typing improves type safety.

The explicit float typing for final_energy and final_pressure helps the type checker understand the variable types and improves code clarity.

@curtischong curtischong changed the title Make types pass uvx ty check Make types pass uvx ty check or uvx mypy torch_sim --strict-equality Jul 16, 2025
@curtischong curtischong changed the title Make types pass uvx ty check or uvx mypy torch_sim --strict-equality Make types pass basedpyright Jul 22, 2025
@curtischong curtischong changed the title Make types pass basedpyright Make types pass uvx basedpyright Jul 23, 2025
@curtischong
Copy link
Collaborator Author

For visibility we are probably going to type the entire repo in parts. We'll merge this PR soon or break it up into more understandable parts



class ModelInterface(ABC):
class ModelInterface(torch.nn.Module, ABC):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice change

@CompRhys CompRhys marked this pull request as ready for review August 8, 2025 01:27
@CompRhys
Copy link
Collaborator

CompRhys commented Aug 8, 2025

@coderabbitai review

Copy link

coderabbitai bot commented Aug 8, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🔭 Outside diff range comments (14)
torch_sim/models/particle_life.py (2)

193-193: Return type bug: energy stored as Python float instead of torch.Tensor

results["energy"] is set to 0.0 (float), but the return type is dict[str, torch.Tensor] and forward later stacks energies across systems. This will break stacking and typing.

Apply this diff:

-        results = {"energy": 0.0}
+        results = {
+            "energy": torch.tensor(0.0, dtype=self.dtype, device=self.device)
+        }

196-199: Incorrect call signature: unexpected keyword 'epsilon' and missing required 'A'

asymmetric_particle_pair_force_jit(dr, A, beta, sigma) does not accept epsilon. Passing epsilon=... will fail type-checks and at runtime.

Use A=self.epsilon and keep sigma keyword:

-        pair_forces = asymmetric_particle_pair_force_jit(
-            distances, sigma=self.sigma, epsilon=self.epsilon
-        )
+        pair_forces = asymmetric_particle_pair_force_jit(
+            distances,
+            A=self.epsilon,
+            sigma=self.sigma,
+        )

Optionally pass beta if you intend a non-default inner radius:

+            beta=DEFAULT_BETA,
torch_sim/models/fairchem.py (1)

324-341: Align forward return type with ModelInterface

The forward method in torch_sim/models/fairchem.py must use the same return annotation as the base interface to satisfy static checks.

• File: torch_sim/models/fairchem.py, forward definition (around lines 324–341)
• Change return type from dict to dict[str, torch.Tensor]

-    def forward(self, state: ts.SimState | StateDict) -> dict:
+    def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:

Also ensure every value in the returned dictionary is a torch.Tensor.

torch_sim/models/orb.py (1)

333-347: Fix model-path handling and isinstance union for portability

Two issues in init:

  • self.system_config = system_config or model.system_config runs before path loading; if model is a str/Path, this fails.
  • isinstance(model, str | Path) may not be portable; use a tuple for compatibility.
-        # Set up system configuration
-        self.system_config = system_config or model.system_config
-        self._max_num_neighbors = max_num_neighbors
-        self._edge_method = edge_method
-        self._half_supercell = half_supercell
-        self.conservative = conservative
-
-        # Load model if path is provided
-        if isinstance(model, str | Path):
-            model = torch.load(model, map_location=self._device)
+        # Load model if path is provided
+        if isinstance(model, (str, Path)):
+            model = torch.load(model, map_location=self._device)
+
+        # Set up system configuration (after model is resolved)
+        self.system_config = system_config or model.system_config
+        self._max_num_neighbors = max_num_neighbors
+        self._edge_method = edge_method
+        self._half_supercell = half_supercell
+        self.conservative = conservative
torch_sim/models/graphpes.py (1)

102-142: Override signature should match ModelInterface (add **kwargs)

ModelInterface.forward accepts **kwargs. Overriding without **kwargs can trigger basedpyright override-compatibility errors. Accept and ignore extra kwargs in the wrapper.

Apply this diff to the method signature:

-    def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:
+    def forward(
+        self, state: ts.SimState | StateDict, **kwargs: Any
+    ) -> dict[str, torch.Tensor]:

Optional: instead of # type: ignore[return-value], cast to the expected mapping to keep the file clean for basedpyright:

-        return self._gp_model.predict(atomic_graph, self._properties)  # type: ignore[return-value]
+        return typing.cast(
+            dict[str, torch.Tensor],
+            self._gp_model.predict(atomic_graph, self._properties),
+        )
torch_sim/models/metatomic.py (2)

51-62: Override signature should match ModelInterface (add **kwargs)

ModelInterface.forward includes **kwargs. This override omits it, which can fail type checking with basedpyright. Accept and ignore extra kwargs.

-    def forward(  # noqa: C901, PLR0915
-        self,
-        state: ts.SimState | StateDict,
-    ) -> dict[str, torch.Tensor]:
+    def forward(  # noqa: C901, PLR0915
+        self,
+        state: ts.SimState | StateDict,
+        **kwargs: Any,
+    ) -> dict[str, torch.Tensor]:

99-116: Path handling bug: .endswith used on Path

model is typed str | Path | None but the implementation calls model.endswith(...). This will raise on Path. Normalize via Path(model) when not using the literal "pet-mad".

-        if model == "pet-mad":
+        if model == "pet-mad":
             path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt"
             self._model = load_model(path).export()
-        elif model.endswith(".ckpt"):
-            path = model
-            self._model = load_model(path).export()
-        elif model.endswith(".pt"):
-            path = model
-            self._model = load_atomistic_model(path, extensions_path)
+        else:
+            path_obj = Path(model)  # type: ignore[arg-type]
+            if path_obj.suffix == ".ckpt":
+                self._model = load_model(str(path_obj)).export()
+            elif path_obj.suffix == ".pt":
+                self._model = load_atomistic_model(str(path_obj), extensions_path)
+            else:
+                raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')
-        else:
-            raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')
torch_sim/models/interface.py (1)

135-170: Ensure all forward overrides include **kwargs

The base ModelInterface.forward declares **kwargs to support model-specific options and satisfy LSP/type checking. The following subclasses currently define forward without **kwargs and should be updated:

  • torch_sim/models/fairchem.py:324
  • torch_sim/models/graphpes.py:177
  • torch_sim/models/lennard_jones.py:356
  • torch_sim/models/mace.py:240
  • torch_sim/models/mattersim.py:113
  • torch_sim/models/metatomic.py:151
  • torch_sim/models/morse.py:346
  • torch_sim/models/orb.py:361
  • torch_sim/models/particle_life.py:213
  • torch_sim/models/sevennet.py:152
  • torch_sim/models/soft_sphere.py:380, 807

Please revise each signature, for example:

-    def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]:
+    def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]:

This ensures all overrides remain signature-compatible and maintain proper extensibility.

torch_sim/models/soft_sphere.py (5)

283-285: Fix wrong masses shape when constructing SimState from dict

torch.ones_like(state["positions"]) produces shape [n_atoms, 3], but masses must be [n_atoms]. This will break both runtime and types.

Apply:

-        if isinstance(state, dict):
-            state = ts.SimState(**state, masses=torch.ones_like(state["positions"]))
+        if isinstance(state, dict):
+            masses = torch.ones(
+                state["positions"].shape[0], dtype=self.dtype, device=self.device
+            )
+            state = ts.SimState(**state, masses=masses)

351-353: Guard against division by zero in force projection

If any pair has zero distance (overlap), (pair_forces / distances) will produce inf/NaN. Clamp/guard the denominator.

-            force_vectors = (pair_forces / distances)[:, None] * dr_vec
+            denom = torch.where(distances > 0, distances, torch.ones_like(distances))
+            force_vectors = (pair_forces / denom)[:, None] * dr_vec

606-610: Breaks when species=None: cannot infer n_species

n_species = len(torch.unique(species)) fails if species is None. You either need species or a way to infer n_species from provided matrices. Minimal fix: infer from any provided *_matrix; otherwise require species.

-        self.species = species
-        n_species = len(torch.unique(species))
+        self.species = species
+        if species is None:
+            if sigma_matrix is not None:
+                n_species = sigma_matrix.shape[0]
+            elif epsilon_matrix is not None:
+                n_species = epsilon_matrix.shape[0]
+            elif alpha_matrix is not None:
+                n_species = alpha_matrix.shape[0]
+            else:
+                raise ValueError(
+                    "Provide `species` or at least one of {sigma_matrix, epsilon_matrix, alpha_matrix} to infer n_species."
+                )
+        else:
+            n_species = int(torch.unique(species).numel())

711-718: Typo in vesin_nl_ts argument: sortisort_id

This will raise a TypeError at runtime.

-                sorti=False,
+                sort_id=False,

870-876: Incorrect stacking/cat of results: double-processing forces/energies/stresses

You first stack 'forces', 'energies', 'stresses' and then immediately overwrite by concatenating them. Follow the same pattern as SoftSphereModel: stack only 'stress' and 'energy'; cat the rest.

-        for key in ("stress", "energy", "forces", "energies", "stresses"):
+        for key in ("stress", "energy"):
             if key in properties:
                 results[key] = torch.stack([out[key] for out in outputs])

(The following cat loop stays as-is.)

torch_sim/integrators/npt.py (1)

1460-1499: Avoid constructing NPTNoseHooverState with momenta=None (type mismatch); initialize momenta and thermostat first

NPTNoseHooverState declares momenta: torch.Tensor. Passing None will fail strict type checking (and can break runtime assumptions). Compute momenta earlier and pass it into the constructor. Also initialize the thermostat using the computed KE.

-        # Create initial state
-        npt_state = NPTNoseHooverState(
+        # Initialize particle momenta before building the state
+        momenta = kwargs.get(
+            "momenta",
+            calculate_momenta(
+                state.positions, state.masses, state.system_idx, kT, seed
+            ),
+        )
+
+        # Initialize thermostat using particle KE
+        KE = calc_kinetic_energy(
+            masses=state.masses, momenta=momenta, system_idx=state.system_idx
+        )
+        thermostat = thermostat_fns.initialize(state.positions.numel(), KE, kT)
+
+        # Create initial state with concrete tensors (no Nones)
+        npt_state = NPTNoseHooverState(
             positions=state.positions,
-            momenta=None,
+            momenta=momenta,
             energy=energy,
             forces=forces,
             masses=state.masses,
             atomic_numbers=atomic_numbers,
             cell=state.cell,
             pbc=state.pbc,
             system_idx=state.system_idx,
             reference_cell=reference_cell,
             cell_position=cell_position,
             cell_momentum=cell_momentum,
             cell_mass=cell_mass,
             barostat=barostat_fns.initialize(1, KE_cell, kT),
-            thermostat=None,
+            thermostat=thermostat,
             barostat_fns=barostat_fns,
             thermostat_fns=thermostat_fns,
         )
-
-        # Initialize momenta
-        momenta = kwargs.get(
-            "momenta",
-            calculate_momenta(
-                npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed
-            ),
-        )
-
-        # Initialize thermostat
-        npt_state.momenta = momenta
-        KE = calc_kinetic_energy(
-            momenta=npt_state.momenta,
-            masses=npt_state.masses,
-            system_idx=npt_state.system_idx,
-        )
-        npt_state.thermostat = thermostat_fns.initialize(
-            npt_state.positions.numel(), KE, kT
-        )
♻️ Duplicate comments (1)
torch_sim/models/interface.py (1)

38-38: Base class now inherits nn.Module — agreed

Consolidating nn.Module at the interface simplifies model inheritance and typing.

🧹 Nitpick comments (4)
torch_sim/models/particle_life.py (2)

51-69: Docstring says “Energies” but function returns forces

asymmetric_particle_pair_force_jit computes force magnitudes, not energies. Update the Returns description to prevent confusion and improve type-check clarity.


184-185: I’d like to inspect the full mapping block in torch_sim/models/particle_life.py to verify why it’s reversed. Could you run:

rg -n "mapping = torch.stack" -A 10 torch_sim/models/particle_life.py
torch_sim/models/mattersim.py (1)

24-34: LGTM: Placeholder class now correctly derives from ModelInterface

This keeps typing consistent when imports fail and still raises eagerly in init. One nit: the docstring says “sevenn is not installed” — should be “mattersim”.

-        It raises an ImportError if sevenn is not installed.
+        It raises an ImportError if MatterSim is not installed.
torch_sim/models/metatomic.py (1)

79-97: Docstring vs implementation mismatch for model=None behavior

Docstring says “If None, defaults to 'pet-mad'” but the code raises ValueError when model is None. Please align behavior with docs, or update the docstring and type to remove None.

Two options:

  • Update implementation to default to "pet-mad":
-        if model is None:
-            raise ValueError(
-                "A model path, or the name of a pre-defined model, must be provided. "
-                'Currently only "pet-mad" is available as a pre-defined model.'
-            )
+        if model is None:
+            model = "pet-mad"
  • Or update the docstring and annotation (remove | None) to reflect that a value is required.
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58601d1 and 30a4350.

📒 Files selected for processing (45)
  • .github/PULL_REQUEST_TEMPLATE.md (1 hunks)
  • examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py (1 hunks)
  • examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py (1 hunks)
  • examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (2 hunks)
  • examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py (1 hunks)
  • examples/scripts/3_Dynamics/3.2_MACE_NVE.py (1 hunks)
  • examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py (1 hunks)
  • examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py (1 hunks)
  • examples/scripts/4_High_level_api/4.1_high_level_api.py (1 hunks)
  • examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py (3 hunks)
  • examples/tutorials/high_level_tutorial.py (1 hunks)
  • examples/tutorials/reporting_tutorial.py (2 hunks)
  • tests/models/test_mattersim.py (3 hunks)
  • tests/models/test_sevennet.py (3 hunks)
  • tests/test_integrators.py (4 hunks)
  • tests/test_monte_carlo.py (2 hunks)
  • tests/test_optimizers.py (15 hunks)
  • tests/test_runners.py (5 hunks)
  • tests/test_state.py (1 hunks)
  • tests/test_trajectory.py (2 hunks)
  • tests/workflows/test_a2c.py (2 hunks)
  • torch_sim/elastic.py (2 hunks)
  • torch_sim/integrators/md.py (2 hunks)
  • torch_sim/integrators/npt.py (8 hunks)
  • torch_sim/integrators/nve.py (1 hunks)
  • torch_sim/integrators/nvt.py (5 hunks)
  • torch_sim/models/fairchem.py (2 hunks)
  • torch_sim/models/graphpes.py (2 hunks)
  • torch_sim/models/interface.py (2 hunks)
  • torch_sim/models/lennard_jones.py (1 hunks)
  • torch_sim/models/mace.py (2 hunks)
  • torch_sim/models/mattersim.py (2 hunks)
  • torch_sim/models/metatomic.py (2 hunks)
  • torch_sim/models/morse.py (1 hunks)
  • torch_sim/models/orb.py (2 hunks)
  • torch_sim/models/particle_life.py (1 hunks)
  • torch_sim/models/sevennet.py (2 hunks)
  • torch_sim/models/soft_sphere.py (3 hunks)
  • torch_sim/monte_carlo.py (2 hunks)
  • torch_sim/optimizers.py (10 hunks)
  • torch_sim/quantities.py (6 hunks)
  • torch_sim/runners.py (1 hunks)
  • torch_sim/state.py (10 hunks)
  • torch_sim/trajectory.py (4 hunks)
  • torch_sim/workflows/a2c.py (8 hunks)
✅ Files skipped from review due to trivial changes (16)
  • examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
  • examples/scripts/4_High_level_api/4.1_high_level_api.py
  • examples/scripts/3_Dynamics/3.2_MACE_NVE.py
  • .github/PULL_REQUEST_TEMPLATE.md
  • examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py
  • examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py
  • torch_sim/monte_carlo.py
  • torch_sim/integrators/md.py
  • torch_sim/models/mace.py
  • torch_sim/integrators/nve.py
  • torch_sim/trajectory.py
  • torch_sim/elastic.py
  • tests/test_integrators.py
  • examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
  • torch_sim/integrators/nvt.py
  • torch_sim/optimizers.py
🧰 Additional context used
🧬 Code Graph Analysis (17)
tests/test_monte_carlo.py (2)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
tests/conftest.py (1)
  • lj_model (34-45)
examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (1)
torch_sim/quantities.py (1)
  • calc_kinetic_energy (101-134)
torch_sim/models/particle_life.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
tests/test_state.py (3)
torch_sim/integrators/nvt.py (1)
  • velocities (270-274)
torch_sim/integrators/npt.py (1)
  • velocities (871-877)
torch_sim/integrators/md.py (1)
  • velocities (45-49)
tests/workflows/test_a2c.py (1)
torch_sim/optimizers.py (2)
  • FireState (442-483)
  • UnitCellFireState (621-692)
torch_sim/models/sevennet.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/fairchem.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/lennard_jones.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/mattersim.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/morse.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/graphpes.py (2)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
tests/models/test_graphpes.py (2)
  • ts_nequip_model (128-134)
  • ts_mace_model (155-161)
tests/test_trajectory.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/metatomic.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/models/orb.py (1)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
examples/tutorials/reporting_tutorial.py (2)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/state.py (1)
  • SimState (26-402)
torch_sim/quantities.py (4)
torch_sim/integrators/npt.py (2)
  • momenta (71-73)
  • velocities (871-877)
torch_sim/integrators/nvt.py (1)
  • velocities (270-274)
torch_sim/integrators/md.py (1)
  • velocities (45-49)
torch_sim/units.py (1)
  • MetalUnits (83-103)
torch_sim/workflows/a2c.py (2)
torch_sim/models/interface.py (1)
  • ModelInterface (38-170)
torch_sim/optimizers.py (2)
  • FireState (442-483)
  • UnitCellFireState (621-692)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, metatomic, tests/models/test_metatomic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: build-docs
🔇 Additional comments (58)
examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py (1)

81-81: LGTM: purely a formatting consolidation of SimState kwargs

No functional changes; kwargs are correct and consistent with surrounding examples.

tests/workflows/test_a2c.py (3)

1-2: LGTM! Appropriate import addition.

The cast import from typing is correctly added to support the type casting needed later in the test.


9-9: LGTM! Correct import addition.

Adding FireState import is necessary for the type cast used later in the test function.


160-160: LGTM! Proper handling of ambiguous return type.

The cast to FireState correctly handles the fact that random_packed_structure can return either a FireState or a tuple. This ensures type safety while maintaining test functionality.

torch_sim/quantities.py (5)

3-4: LGTM! Necessary import addition.

The cast import is correctly added to support type casting in the function implementations.


25-31: LGTM! Excellent improvement to argument handling.

Making arguments keyword-only with the leading * and making momenta optional improves the API design. This enforces clearer calling conventions and prevents positional argument mistakes.


44-50: LGTM! Improved validation logic.

The XOR validation using ((momenta is not None) ^ (velocities is not None)) is more concise and clearer than separate checks. The explicit cast to torch.Tensor when velocities is guaranteed to be non-None ensures type safety.


73-98: LGTM! Consistent parameter improvements.

The same improvements applied to calc_temperature: keyword-only arguments, optional momenta parameter, and explicit keyword arguments in the calc_kT call. The MetalUnits type annotation is more precise than the previous generic object type.


101-122: LGTM! Consistent implementation pattern.

The calc_kinetic_energy function follows the same improved pattern with keyword-only arguments, optional momenta parameter, and consolidated XOR validation logic.

torch_sim/state.py (7)

11-11: LGTM! Appropriate imports for enhanced type safety.

The additional imports support the generic type improvements and explicit casting used throughout the file.


112-147: LGTM! Fixed constructor validation logic.

The system_idx validation and assignment logic is now more robust. The validation properly handles the case where system_idx is provided, and the assignment is correctly performed after validation.


254-259: LGTM! Proper error handling for volume property.

Raising a ValueError for non-PBC systems is the correct approach since volume is undefined without periodic boundaries. This is better than returning None which could lead to silent errors downstream.


368-368: LGTM! Explicit type casting for method return.

The cast ensures the return type matches the method signature while maintaining type safety.


408-420: LGTM! Simplified mixin interface.

Removing the momenta property and adding the required attributes simplifies the mixin interface. The reference_row_vector_cell property provides a clean way to access the reference cell in row vector format.


483-490: LGTM! Enhanced type safety with generics.

The SimStateT generic type variable and its usage in function signatures preserve subclass types through operations like device conversion, which is excellent for type safety.


972-972: LGTM! Appropriate type casting for validation.

The cast ensures type safety when accessing the n_systems property on list elements during validation.

examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py (1)

100-100: LGTM! Consistent argument ordering.

Moving pbc=True after atomic_numbers maintains consistency with other example scripts while preserving functionality.

torch_sim/models/morse.py (1)

115-115: LGTM! Consistent interface inheritance.

Removing torch.nn.Module inheritance aligns with the architectural change where ModelInterface now provides the PyTorch module functionality. This centralizes the inheritance pattern and reduces duplication across model implementations.

torch_sim/models/particle_life.py (2)

86-86: Inheritance change aligns with new interface hierarchy

Good move inheriting only from ModelInterface now that it subclasses nn.Module. super().init() is present.


155-161: Ignore CUDA scalar conversion suggestion

The vesin_nl_ts signature clearly declares

def vesin_nl_ts(
    …,
    cutoff: torch.Tensor,
    …
)

so it expects a tensor (not a Python float). Converting self.cutoff to a Python float—even with .item()—would break the intended API. You can leave the call as-is.

Likely an incorrect or invalid review comment.

torch_sim/models/lennard_jones.py (2)

122-122: Inheritance change matches new design

Switching to inherit only from ModelInterface (which is an nn.Module) is consistent with the refactor.


265-271: Ignore float conversion suggestion for cutoff
The vesin_nl_ts function is scripted to accept a 0-dim torch.Tensor and internally invokes cutoff.item(), so passing self.cutoff (even on CUDA) is correct. Converting it to a Python float would mismatch the declared torch.Tensor parameter and break the JIT call.

Likely an incorrect or invalid review comment.

examples/tutorials/high_level_tutorial.py (1)

135-136: Keyword-only call is correct

Switch to keyword args for calc_kinetic_energy matches the updated signature and improves clarity.

tests/test_monte_carlo.py (2)

6-6: Type import aligns test with new model interface

Importing ModelInterface keeps tests consistent with the refactor.


116-117: Fixture type updated to ModelInterface is appropriate

Typing lj_model as ModelInterface matches how models are consumed across the codebase.

torch_sim/models/fairchem.py (2)

48-56: Placeholder class now inherits ModelInterface

Good change; keeps type checkers happy even when FairChem isn’t installed. Raising in init is fine here.


73-73: Main class inheritance updated appropriately

Inheriting only from ModelInterface centralizes nn.Module via the interface. super().init() is correctly called in init.

tests/test_trajectory.py (2)

12-12: ModelInterface import aligns the test with the new model abstraction

Good change. This keeps the test consistent with ModelInterface-based typing across the codebase.


752-756: Type-narrowing to ModelInterface for energy_calculator — LGTM

Annotating the model parameter as ModelInterface is correct here since the calculator calls model(state) and accesses "energy".

examples/tutorials/reporting_tutorial.py (2)

209-209: Importing ModelInterface for tutorial examples — good consistency

This ensures the tutorial code type-checks under basedpyright with the new interface.


218-221: Update calculate_energy to accept ModelInterface — correct and minimal

Accurately reflects the contract that models implement ModelInterface and return an output dict with "energy".

examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (2)

101-102: Keyword argument order tweak for SimState — OK

Reordering pbc after atomic_numbers using keyword args is harmless and reads cleanly.


151-156: Keyword-only kinetic energy call and variable rename are correct

  • calc_kinetic_energy now requires keyword args; usage is correct.
  • Passing kinetic_energy to get_pressure matches the expected signature (stress, kinetic_energy, volume).
tests/test_state.py (1)

503-505: Require velocities and reference_cell in DeformState — good alignment

Making these parameters required reflects the updated mixin expectations and clarifies test intent.

torch_sim/models/mattersim.py (1)

42-55: LGTM: Inherit only from ModelInterface

Given ModelInterface already subclasses torch.nn.Module, this simplifies the hierarchy without functional changes. super().__init__() still initializes nn.Module state.

torch_sim/models/orb.py (2)

42-52: LGTM: Placeholder class now uses ModelInterface

Consistent with the new interface-first design and keeps the eager ImportError behavior.


250-274: LGTM: Inherit only from ModelInterface

Matches the project-wide refactor and maintains proper nn.Module initialization via super().__init__().

torch_sim/models/sevennet.py (2)

35-45: LGTM: Placeholder class inherits ModelInterface

Keeps error surfacing while maintaining consistent typing.


47-59: LGTM: Inherit only from ModelInterface

Consistent with project-wide refactor; nn.Module is still initialized via super().__init__().

torch_sim/models/graphpes.py (1)

39-49: ImportError fallback: inheritance change is correct

Using ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent now that ModelInterface subclasses nn.Module.

torch_sim/models/metatomic.py (1)

39-49: ImportError fallback: inheritance change is correct

Using ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent.

tests/test_runners.py (1)

26-30: Keyword-arg calls to calc_kinetic_energy are correct

Switching to momenta=..., masses=... improves clarity and typing; matches the current API and helps basedpyright.

Also applies to: 63-67, 119-123, 170-174

torch_sim/models/interface.py (1)

72-75: Declaring protected attributes at class level is the right move

These declarations unblock basedpyright by guaranteeing attribute presence on subclasses.

torch_sim/models/soft_sphere.py (3)

438-438: Multi-model now inherits ModelInterface — good alignment with the refactor

No issues with the base class swap.


597-602: Private attribute renaming is fine; confirm public properties back these

Setting _device/_dtype/_compute_* is consistent with SoftSphereModel. As above, ensure ModelInterface exposes the public device/dtype/compute_* properties used elsewhere.

Same script as in Line 133 comment can be used to confirm the properties exist.


133-133: ModelInterface defines all required properties
Verified that ModelInterface declares the following @property methods in torch_sim/models/interface.py:

  • device (l. 77)
  • dtype (l. 89)
  • compute_stress (l. 101)
  • compute_forces (l. 113)

No further changes needed.

torch_sim/integrators/npt.py (2)

17-17: Importing ModelInterface for type annotations — good

This aligns the integrators with the new interface abstraction.


1436-1436: Consistent keyword-only calls to calc_kinetic_energy — LGTM

All updated calls now use masses=/momenta= and pass system_idx where per-system results are needed. This matches the API change and improves clarity.

Also applies to: 1491-1494, 1550-1552, 1555-1555, 1607-1608

torch_sim/workflows/a2c.py (8)

22-22: LGTM! Import added for ModelInterface.

The import is correctly added to support the type annotation changes throughout the file.


232-232: LGTM! Return type annotation correctly reflects conditional behavior.

The union type FireState | tuple[FireState, list[np.ndarray]] accurately represents that the function returns either a FireState or a tuple based on whether log is provided.


580-597: LGTM! Performance optimization for composition handling.

The refactoring to convert restrict_to_compositions to a set upfront and use .add() instead of list operations is more efficient. The explicit type annotations also improve type safety.


614-614: LGTM! Explicit type annotation improves type safety.

The explicit typing of candidates as list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] makes the return type clear and helps with type checking.


709-709: LGTM! Parameter type updated to ModelInterface.

The change from torch.nn.Module to ModelInterface is consistent with the codebase refactoring and provides more specific typing for models in the simulation framework.


711-711: LGTM! Return type expanded to include energy and pressure lists.

The return type correctly reflects that the function now returns final energies and pressures as lists of floats in addition to the state and logger. This is consistent with the function implementation that prints and returns these values.


755-756: LGTM! Type annotations updated for consistency.

The parameter and return type annotations are properly updated to use the new type system with explicit UnitCellFireState and dict[str, torch.Tensor] types.


330-331: Missing return statement in else branch.

The function has a conditional return at line 329-330 but is missing a return statement for the case when log is None. This will cause a runtime error.

Apply this diff to fix the missing return:

     if log is not None:
         return state, log
-
+    
     return state

Likely an incorrect or invalid review comment.

tests/test_optimizers.py (2)

9-9: LGTM! Import added for ModelInterface.

The import is correctly added to support the updated type annotations throughout the test file.


27-27: LGTM! Model parameter types updated to ModelInterface.

All test functions now correctly use ModelInterface instead of torch.nn.Module for the lj_model parameter. This is consistent with the codebase refactoring where models inherit from ModelInterface rather than directly from torch.nn.Module.

Also applies to: 66-66, 115-115, 189-189, 205-205, 213-213, 276-276, 329-329, 418-418, 452-452, 471-471, 507-507, 596-596, 711-711, 787-787

Comment on lines +365 to +369
1: {
"ke": lambda state: calc_kinetic_energy(
velocities=state.velocities, masses=state.masses
)
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify velocities usage in FIRE reporter

Using velocities=state.velocities is fine if SimState guarantees velocities during FIRE optimization and calc_kinetic_energy supports it. Please sanity-check both assumptions.

Run this script to verify the function signature and SimState fields:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Check calc_kinetic_energy signature:"
rg -n "def calc_kinetic_energy\\(" -A 5 torch_sim | sed -n '1,12p'

echo
echo "Confirm it accepts 'velocities' and/or 'momenta':"
rg -n "calc_kinetic_energy\\(" -A 2 torch_sim | rg -n "velocities|momenta|masses" -n

echo
echo "Check that SimState defines 'velocities' in contexts used by FIRE:"
rg -n "class SimState" -n torch_sim
rg -n "velocities" torch_sim | head -n 20

Length of output: 4095


Ensure SimState provides velocities for the FIRE reporter

While calc_kinetic_energy does accept a velocities tensor (see torch_sim/quantities.py:def calc_kinetic_energy signature), SimState does not currently define or populate a velocities attribute:

• torch_sim/quantities.py (lines 101–107): supports velocities: Tensor | None
• torch_sim/state.py (line 966): TODO to pass velocities from external engines, no field defined
• torch_sim/runners.py (lines 49, 52): reporter calls velocities=state.velocities

Please address this by either:

  • Adding/populating a velocities field on SimState during FIRE optimizations, or
  • Switching the reporter to use momenta=state.momenta (which SimState does define)
🤖 Prompt for AI Agents
In tests/test_runners.py around lines 365 to 369, the reporter calls
calc_kinetic_energy with state.velocities, but SimState does not define or
populate a velocities attribute. To fix this, either add and populate a
velocities field on SimState during FIRE optimizations so it can be passed
correctly, or modify the reporter to use momenta=state.momenta instead, since
momenta is already defined on SimState.

Comment on lines +48 to +53
"kinetic_energy": lambda state: calc_kinetic_energy(
velocities=state.velocities, masses=state.masses
),
"temperature": lambda state: calc_kT(
velocities=state.velocities, masses=state.masses
),
Copy link

@coderabbitai coderabbitai bot Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Kinetic/temperature calculators now depend on velocities — verify availability or guard

Switch to keyword args looks good and aligns with API. However, if any integrator/state doesn’t populate state.velocities, these lambdas will raise at runtime. Either ensure all integrate paths guarantee velocities, or guard the calculators (e.g., skip or return NaN when velocities are absent).

You can quickly sanity-check for residual usage of momenta vs. assured velocities support with:


🏁 Script executed:

#!/bin/bash
# Inspect SimState and integrators for velocities/momenta usage
rg -n "velocities" torch_sim | head -n 50
rg -n "momenta" torch_sim | head -n 50
# Check if SimState defines 'velocities' as required attr
rg -n "class SimState" -A 40 torch_sim/state.py

Length of output: 11438


🏁 Script executed:

#!/bin/bash
# Inspect runner reporters setup for state.velocities usage
rg -n "reporters" -A30 torch_sim/runners.py

# Inspect how Trajectory uses save_velocities and when velocities are attached
rg -n "class Trajectory" -A200 torch_sim/trajectory.py

# Check save_velocities flag defaults and how velocities propagate into state
rg -n "save_velocities" -A20 torch_sim/trajectory.py

Length of output: 15077


🏁 Script executed:

#!/bin/bash
# Check for StaticState definition in the codebase
rg -n "class StaticState" -A20 torch_sim/state.py

Length of output: 51


🏁 Script executed:

#!/bin/bash
# Locate StaticState definition across the repo
rg -n "class StaticState" -A20 .

# Locate MDState definition across the repo
rg -n "class MDState" -A20 .

# Show the top of runners.py (including default prop_calculators)
sed -n '1,100p' torch_sim/runners.py

Length of output: 6753


Guard kinetic/temperature reporters against missing velocities

In _configure_reporter (torch_sim/runners.py), the default kinetic_energy and temperature lambdas assume state.velocities always exists. While MDState provides a velocities property, base SimState and static‐optimization states do not, leading to an AttributeError if users request these props outside an MD run.

• Location:

  • torch_sim/runners.py, around lines 46–53 in _configure_reporter
    • Suggested change: wrap each calculator in a guard that checks for state.velocities and returns NaN (or skips) when unavailable

Example diff:

--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ def _configure_reporter(...):
-    possible_properties = {
-        "kinetic_energy": lambda state: calc_kinetic_energy(
-            velocities=state.velocities, masses=state.masses
-        ),
-        "temperature": lambda state: calc_kT(
-            velocities=state.velocities, masses=state.masses
-        ),
+    possible_properties = {
+        "kinetic_energy": lambda state: (
+            calc_kinetic_energy(velocities=state.velocities, masses=state.masses)
+            if getattr(state, "velocities", None) is not None
+            else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+        ),
+        "temperature": lambda state: (
+            calc_kT(velocities=state.velocities, masses=state.masses)
+            if getattr(state, "velocities", None) is not None
+            else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+        ),

This ensures reporters won’t crash when running on states without velocities.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"kinetic_energy": lambda state: calc_kinetic_energy(
velocities=state.velocities, masses=state.masses
),
"temperature": lambda state: calc_kT(
velocities=state.velocities, masses=state.masses
),
possible_properties = {
"kinetic_energy": lambda state: (
calc_kinetic_energy(velocities=state.velocities, masses=state.masses)
if getattr(state, "velocities", None) is not None
else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
),
"temperature": lambda state: (
calc_kT(velocities=state.velocities, masses=state.masses)
if getattr(state, "velocities", None) is not None
else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
),
# … other properties …
}
🤖 Prompt for AI Agents
In torch_sim/runners.py around lines 48 to 53, the lambdas for "kinetic_energy"
and "temperature" assume state.velocities always exists, which can cause
AttributeError for states without velocities. Modify these lambdas to first
check if state.velocities is present; if not, return NaN to prevent crashes.
This guard ensures the reporters handle states lacking velocities gracefully.

Copy link
Collaborator Author

@curtischong curtischong Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this is a larger issue. we'll prob fix this in another PR

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@curtischong, that makes perfect sense! Keeping the scope focused on type checking compliance for this PR is a good approach. The runtime safety issue with missing velocities can definitely be addressed separately.

Would you like me to create a GitHub issue to track this for the future PR? It would help ensure the velocities availability issue doesn't get forgotten.


✏️ Learnings added
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

[skip ci]

Signed-off-by: Rhys Goodall <[email protected]>
torch.Tensor: Scalar temperature value
"""
if momenta is not None and velocities is not None:
if not ((momenta is not None) ^ (velocities is not None)):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the value error is slightly misleading. if both of them are none then the xor will evaluate to false. but the not will negate it and the entire statement will be true.

so if they pass in no velocity and no mementia, the value error will raise. but the "not both" is misleading

Copy link
Collaborator Author

@curtischong curtischong Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a minor nit though. Oh I get it. you merged the if statements. How about: "Must pass in one of either momenta or velocities"

If system_idx is provided: Tensor of kinetic energies per system
"""
if momenta is not None and velocities is not None:
if not ((momenta is not None) ^ (velocities is not None)):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe update this value error as well?

@CompRhys CompRhys changed the title Make types pass uvx basedpyright Improve Typing of ModelInterface Aug 8, 2025
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/test_quantities.py (2)

3-3: Use public API for tensor type annotation.

Importing from torch._tensor accesses a private module. Use torch.Tensor directly for better compatibility with type checkers.

-from torch._tensor import Tensor
+

Then update the type annotations to use torch.Tensor instead of Tensor:

-def single_system_data() -> dict[str, Tensor]:
+def single_system_data() -> dict[str, torch.Tensor]:
-def batched_system_data() -> dict[str, Tensor]:
+def batched_system_data() -> dict[str, torch.Tensor]:
-def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None:
+def test_calc_kinetic_energy_single_system(single_system_data: dict[str, torch.Tensor]) -> None:

And similarly for other function signatures.


127-137: Consider adding batched system test coverage.

The temperature calculation test correctly verifies the unit conversion relationship with kT. Consider adding a test case for batched systems to match the coverage of other test functions.

def test_calc_temperature_batched_system(batched_system_data: dict[str, torch.Tensor]) -> None:
    temp = quantities.calc_temperature(
        masses=batched_system_data["masses"],
        velocities=batched_system_data["velocities"],
        system_idx=batched_system_data["system_idx"],
    )
    kt = quantities.calc_kT(
        masses=batched_system_data["masses"],
        velocities=batched_system_data["velocities"],
        system_idx=batched_system_data["system_idx"],
    )
    assert torch.allclose(temp, kt / MetalUnits.temperature)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 212f13e and 923d213.

📒 Files selected for processing (2)
  • tests/test_quantities.py (1 hunks)
  • torch_sim/quantities.py (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • torch_sim/quantities.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
🧬 Code Graph Analysis (1)
tests/test_quantities.py (2)
torch_sim/units.py (1)
  • MetalUnits (83-103)
torch_sim/quantities.py (3)
  • calc_kinetic_energy (101-134)
  • calc_kT (25-70)
  • calc_temperature (73-97)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (58)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.3_Batched_neighbor_list.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.3_Fairchem.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/models/test_mace.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, graphpes, tests/models/test_graphpes.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-core (macos-14, 3.12, lowest-direct)
  • GitHub Check: test-core (ubuntu-latest, 3.11, highest)
  • GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
  • GitHub Check: test-core (macos-14, 3.11, highest)
  • GitHub Check: build-docs
🔇 Additional comments (4)
tests/test_quantities.py (4)

13-26: LGTM!

The fixture provides accurate test data. The expected kinetic energy (13.5) and kT (4.5) values are mathematically correct for the given masses and velocities.


29-44: LGTM!

The batched system fixture provides accurate test data. The expected kinetic energies ([3.0, 24.0]) and kT values ([1.0, 8.0]) per system are mathematically correct.


47-92: Excellent test coverage for kinetic energy calculations.

The tests comprehensively cover:

  • Single and batched system scenarios
  • Both velocity and momentum input modes
  • Proper error handling for invalid argument combinations
  • Correct usage of keyword-only arguments

94-125: LGTM!

The kT calculation tests provide comprehensive coverage matching the kinetic energy test structure, correctly testing both single and batched systems with velocity and momentum inputs.

Copy link
Collaborator Author

@curtischong curtischong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall I approve - 2 minor nits you might want to change before we merge

@CompRhys CompRhys merged commit f6cd006 into main Aug 8, 2025
92 of 93 checks passed
@CompRhys CompRhys deleted the ty branch August 8, 2025 02:27
@coderabbitai coderabbitai bot mentioned this pull request Aug 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla-signed Contributor license agreement signed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants