Conversation

curtischong
Collaborator

@curtischong curtischong commented Jul 27, 2025

Summary

This PR makes handling SimStates much simpler. Rather than guessing whether an attribute is per-atom, per-system, or global (via infer_property_scope), SimState (and its child classes) now explicitly defines the scope of each attribute.

#227 was a precursor to this PR

By explicitly defining a property scope for each attribute we:

  1. Add documentation to each attribute, teaching readers which attributes are per-atom/per-system/global.
  2. Prevent `n_atoms (1) and n_systems (1) are equal, which means shapes cannot be inferred unambiguously` errors, which can occur when systems have a single atom.

In general I think we should prevent these edge cases from occurring: they are uncommon, but they happen often enough to matter.

I added validation logic inside the SimState class to ensure that each child class explicitly defines the property scope for all its attributes.

Note: this PR introduces breaking changes. In particular, it makes all of the classes that derive from SimState kw_only. This is because the parent class assigns values to its attribute registries (e.g. _system_attributes = {"reference_cell"}), and child classes that inherit from it place their own fields AFTER the parent's attributes. Since _system_attributes has an assigned value, all fields that follow it (i.e. the ones in the child class) must be keyword-only.

I think this breaking change is fine because using keywords when constructing complex objects like FireState eliminates the chance of a user mixing up different torch.Tensor constructor params.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommend installing the pre-commit hooks that run in CI locally to speed up the development process. Simply run pip install pre-commit && pre-commit install to install the hooks, which will check your code before each commit.

Summary by CodeRabbit

  • New Features

    • Introduced explicit per-atom/per-system/global attribute scopes and added a per-atom tracked attribute "last_permutation".
    • Constructors now return NaN-filled tensors instead of None for omitted model outputs (e.g., forces/stress).
  • Bug Fixes

    • Enforced that every state attribute is declared in exactly one scope and that tensor attributes cannot be None.
  • Documentation

    • Tutorials updated to use the new scope-based attribute inspection API.
  • Tests

    • Added/updated tests for scope inspection and enforcement errors; API rename reflected in tests.
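The NaN-placeholder behaviour noted under New Features can be sketched as follows; the helper name is illustrative, not the actual torch_sim API.

```python
import torch

def nan_stress(n_systems: int, device=None, dtype=torch.float64) -> torch.Tensor:
    """Per-system 3x3 stress placeholder; NaN marks 'not computed'.

    Returning a NaN-filled tensor of the expected shape (instead of None)
    means downstream code never has to branch on missing outputs.
    """
    return torch.full((n_systems, 3, 3), torch.nan, device=device, dtype=dtype)

stress = nan_stress(2)
```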

@cla-bot cla-bot bot added the cla-signed Contributor license agreement signed label Jul 27, 2025
@curtischong curtischong marked this pull request as draft July 27, 2025 03:25

coderabbitai bot commented Jul 27, 2025

Walkthrough

Adds explicit per-atom, per-system, and global attribute registries and enforces them at subclass creation; replaces dynamic scope inference with get_attrs_for_scope; many state dataclasses, integrators, optimizers, runners, examples, and tests updated to declare and use these scope sets and new atom-level attributes (e.g., last_permutation).

Changes

Cohort / File(s) Change Summary
Core state refactor
torch_sim/state.py
Introduces _atom_attributes, _system_attributes, _global_attributes; adds __init_subclass__ checks to enforce scopes and non-None tensor attributes; removes infer_property_scope/_get_property_attrs; adds get_attrs_for_scope and updates state split/concat/filter utilities to use explicit scopes.
MD integrators
torch_sim/integrators/md.py
Declares `_atom_attributes` on MDState, extending `SimState._atom_attributes` with the per-atom MD fields (e.g. "forces", "velocities").
NPT / NVT integrators
torch_sim/integrators/npt.py, torch_sim/integrators/nvt.py
Declares/extends _atom_attributes, _system_attributes, and _global_attributes on NPT/NVT states to include forces, velocities, energy, stress, cell-*, thermostat/barostat fields.
Monte Carlo / Hybrid swap
torch_sim/monte_carlo.py, examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py, examples/tutorials/hybrid_swap_tutorial.py
Adds last_permutation: torch.Tensor to HybridSwap/Swap state dataclasses and registers "last_permutation" in _atom_attributes; adds _system_attributes including "energy" to SwapMCState.
Optimizers / FIRE / GD states
torch_sim/optimizers.py
Introduces module-level attribute sets (_md_atom_attributes, _fire_system_attributes, _fire_global_attributes); adds _atom_attributes/_system_attributes/_global_attributes to GDState, FireState, UnitCellGDState, UnitCellFireState, FrechetCellFIREState; enforces @dataclass(kw_only=True) for several states and adjusts related declarations.
Runners / StaticState
torch_sim/runners.py
Adds _atom_attributes and _system_attributes to StaticState; replaces None defaults with NaN-filled tensors for forces/stress when model outputs are disabled.
Tutorials / Examples
examples/tutorials/state_tutorial.py
Replaces infer_property_scope usage with get_attrs_for_scope and updates examples to iterate and display attributes per explicit scope.
Tests
tests/test_state.py
Replaces tests for infer_property_scope with get_attrs_for_scope; renames tests; adds tests enforcing that all attributes have scopes and no duplicates; updates DeformState to combine parent _system_attributes.
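Based on this walkthrough, get_attrs_for_scope replaces shape-based inference with a lookup over the declared registries. A hedged sketch (the real implementation lives in torch_sim/state.py):

```python
from collections.abc import Iterator
from typing import Any

_SCOPE_TO_REGISTRY = {
    "per-atom": "_atom_attributes",
    "per-system": "_system_attributes",
    "global": "_global_attributes",
}

def get_attrs_for_scope(state: Any, scope: str) -> Iterator[tuple[str, Any]]:
    """Yield (name, value) pairs for the attributes declared in `scope`."""
    if scope not in _SCOPE_TO_REGISTRY:
        raise ValueError(f"Unknown scope {scope!r}; expected {sorted(_SCOPE_TO_REGISTRY)}")
    names = getattr(type(state), _SCOPE_TO_REGISTRY[scope])
    for name in sorted(names):  # sorted for deterministic iteration
        yield name, getattr(state, name)

class Demo:
    _atom_attributes = {"positions", "masses"}
    _system_attributes = {"cell"}
    _global_attributes = {"pbc"}
    def __init__(self):
        self.positions, self.masses, self.cell, self.pbc = "P", "M", "C", True

pairs = dict(get_attrs_for_scope(Demo(), "per-atom"))
```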

Sequence Diagram(s)

sequenceDiagram
    participant Dev
    participant StateClass
    participant SimState
    participant Consumer

    Dev->>StateClass: Define dataclass fields + declare scope sets
    StateClass->>SimState: class creation triggers __init_subclass__
    SimState->>SimState: validate scopes & non-None tensor attributes
    Consumer->>SimState: call get_attrs_for_scope(state, scope)
    SimState-->>Consumer: yield (name, value) pairs for requested scope

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Possibly related PRs

  • Remove unbatched code #206 — Adds/exposes per-atom last_permutation in MC state metadata (closely related to the Hybrid/Swap attribute changes).
  • Fix simstate concatenation [2/2] #232 — Introduces SimState subclass-time validations and non-None tensor enforcement that interact with newly added attributes and scope declarations.

Suggested labels

enhancement

Suggested reviewers

  • CompRhys
  • orionarcher

Poem

Hopping through schemas with tidy delight,
I tag each atom and name it just right.
Scopes are enforced at class creation time,
No missing tensors, each field in its rhyme.
A rabbit nods — the state is neat and bright 🐇


Comment on lines 381 to 417
def __init_subclass__(cls, **kwargs) -> None:
"""Enforce all of child classes's attributes are specified in _atom_attributes,
_system_attributes, or _global_attributes.
"""
We should document the new requirements somewhere

@curtischong curtischong changed the title use hardcoded attr names in SimState Define attribute scopes in SimStates Jul 27, 2025
@curtischong
Collaborator Author

curtischong commented Jul 27, 2025

the tests are failing because of the issue pointed out in this PR: #219. I'll revisit this tomorrow

@curtischong curtischong force-pushed the classify-range-of-simstate-feats2 branch 2 times, most recently from 7395098 to 0157533 Compare August 3, 2025 21:43
@curtischong curtischong changed the base branch from main to fix-simstate-concatenation August 3, 2025 21:43
@curtischong
Collaborator Author

:'( I just got a `ValueError: n_atoms (1) and n_systems (1) are equal, which means shapes cannot be inferred unambiguously.` error. We should probably merge this in.

@curtischong curtischong force-pushed the fix-simstate-concatenation branch from 4c49f21 to 84d6750 Compare August 8, 2025 03:14
@curtischong curtischong force-pushed the classify-range-of-simstate-feats2 branch from 18182d4 to 23eea14 Compare August 8, 2025 03:18
@curtischong curtischong force-pushed the fix-simstate-concatenation branch from 6755e8b to 3d23f7d Compare August 8, 2025 14:03
Base automatically changed from fix-simstate-concatenation to main August 8, 2025 14:25
An error occurred while trying to automatically change base from fix-simstate-concatenation to main August 8, 2025 14:25
add test for the new pre-defined attributes

wip make fire simstate attributes predefined

rename features to attributes

define attribute scope for fire state

consolidate attribute definitions and add it to npt

rm infer_property_scope import

__init_subclass__ to enforce that all attributes are specified

find callable attributes

do not check some user attributes

filter for @properties

__init_subclass__ doesn't catch for static state since it's a grandchild state

added documentation for get_attrs_for_scope

more examples for the other scopes

cleaner documentation

rename batch to system

more docs cleanup

test for running the init subclass

more tests

define more scope for the integrators

split state handles none properties a bit better
@curtischong curtischong force-pushed the classify-range-of-simstate-feats2 branch from 074f100 to c4f8ee0 Compare August 9, 2025 02:02

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
torch_sim/state.py (1)

867-905: concatenate_states: skip collecting system_idx to avoid redundant work

System indices are recomputed via new_system_indices; also collecting/catting system_idx above is redundant and briefly incorrect before being overwritten.

-    for prop, val in get_attrs_for_scope(state, "per-atom"):
-        # if hasattr(state, prop):
-        per_atom_tensors[prop].append(val)
+    for prop, val in get_attrs_for_scope(state, "per-atom"):
+        if prop == "system_idx":
+            continue
+        per_atom_tensors[prop].append(val)
🧹 Nitpick comments (4)
torch_sim/optimizers.py (2)

333-378: UnitCell GD init: ensure dtype consistency for pressure and identity

torch.eye defaults to float32; if model dtype is float64 this will trigger promotions. Make eye explicitly match dtype to avoid subtle perf/dtype issues.

-        pressure = scalar_pressure * torch.eye(3, device=device)
+        pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)

-            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
+            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device, dtype=dtype).unsqueeze(
                 0
             ).expand(state.n_systems, -1, -1)

-            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
-                3, device=device
+            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
+                3, device=device, dtype=dtype
             ).unsqueeze(0).expand(state.n_systems, -1, -1)

483-536: FireState: docstring mentions momenta but no property implemented

Either implement the property or remove from docstring. Implementing it is tiny and useful.

 class FireState(SimState):
@@
     Properties:
         momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3],
             calculated as velocities * masses
@@
     n_pos: torch.Tensor
+
+    @property
+    def momenta(self) -> torch.Tensor:
+        """Atomwise momenta: velocities * masses."""
+        return self.velocities * self.masses.unsqueeze(-1)
torch_sim/state.py (2)

449-494: Scope enforcement: considers ClassVar/properties/methods; one small improvement

Implementation correctly avoids false positives (ClassVar, properties, methods). Consider tightening the error message by including the class name and guidance to add the attribute to exactly one scope.

-                raise TypeError(
-                    f"Attribute '{attr_name}' is not defined in {cls.__name__} in any "
-                    "of _atom_attributes, _system_attributes, or _global_attributes"
-                )
+                raise TypeError(
+                    f"{cls.__name__}: attribute '{attr_name}' is not declared in any "
+                    "scope. Add it to exactly one of: _atom_attributes, "
+                    "_system_attributes, or _global_attributes."
+                )

612-636: get_attrs_for_scope: LGTM; optional determinism

Function is correct and defensive. For deterministic iteration (helps debugging, reproducible dumps), consider yielding in sorted order.

-    for attr_name in attr_names:
+    for attr_name in sorted(attr_names):
         yield attr_name, getattr(state, attr_name)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between efa7e00 and 5e181ed.

📒 Files selected for processing (10)
  • examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py (1 hunks)
  • examples/tutorials/hybrid_swap_tutorial.py (2 hunks)
  • tests/test_state.py (3 hunks)
  • torch_sim/integrators/md.py (1 hunks)
  • torch_sim/integrators/npt.py (2 hunks)
  • torch_sim/integrators/nvt.py (1 hunks)
  • torch_sim/monte_carlo.py (1 hunks)
  • torch_sim/optimizers.py (10 hunks)
  • torch_sim/runners.py (2 hunks)
  • torch_sim/state.py (15 hunks)
🚧 Files skipped from review as they are similar to previous changes (8)
  • torch_sim/integrators/nvt.py
  • examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
  • tests/test_state.py
  • torch_sim/monte_carlo.py
  • torch_sim/runners.py
  • torch_sim/integrators/md.py
  • examples/tutorials/hybrid_swap_tutorial.py
  • torch_sim/integrators/npt.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
torch_sim/optimizers.py (1)
torch_sim/state.py (2)
  • SimState (29-493)
  • DeformGradMixin (497-533)
torch_sim/state.py (2)
torch_sim/properties/correlations.py (2)
  • update (178-196)
  • append (57-71)
tests/test_correlations.py (1)
  • split (39-42)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (45)
🔇 Additional comments (9)
torch_sim/optimizers.py (4)

82-84: GDState scopes: LGTM

Per-atom and per-system scope declarations for forces and energy are correct and minimal.


178-239: UnitCellGDState scopes: LGTM

Scopes cover reference_cell (via mixin), stress/pressure/cell_* tensors, and global flags. Matches SimState’s enforcement.


748-751: UnitCellFireState: scope sets look correct

Atom/system/global registries match the class’ declared fields (including DeformGradMixin.reference_cell).


1040-1043: FrechetCellFIREState: scope sets look correct

Atom/system/global registries align with fields and mixin. No duplicates across scopes.

torch_sim/state.py (5)

88-96: Default scope sets on SimState: LGTM

Sane defaults and clear separation: per-atom (positions/masses/atomic_numbers/system_idx), per-system (cell), global (pbc).


415-426: Subclass enforcement: good placement and coverage

Both “no Tensor | None” and “all attributes must be scoped” checks are in init_subclass. Solid guardrails.


638-687: Filtering by masks: remapping system_idx is correct

Handling of system_idx with remapping to consecutive indices is correct and robust when systems are removed.


708-722: Split state: sensible split across scopes; global attrs copied

Per-atom/system splits and zeroed system_idx per single-system outputs are correct.


778-787: Pop states: correct masking and split; good non-mutating semantics

The keep/pop flow and subsequent split are sound.
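The split/pop behaviour reviewed above can be illustrated with a small sketch (not the actual torch_sim code): per-atom tensors are chunked by system_idx, with one chunk per system.

```python
import torch

def split_per_atom(values: torch.Tensor, system_idx: torch.Tensor) -> list[torch.Tensor]:
    """Split a batched per-atom tensor into one chunk per system.

    Assumes system_idx is sorted and contiguous (0, 0, ..., 1, 1, ...),
    as in a concatenated SimState.
    """
    counts = torch.bincount(system_idx).tolist()  # atoms per system, in order
    return list(torch.split(values, counts))

positions = torch.arange(10.0).reshape(5, 2)
system_idx = torch.tensor([0, 0, 0, 1, 1])
chunks = split_per_atom(positions, system_idx)
```

Each resulting chunk can then seed a single-system state whose own system_idx restarts at zero, matching the split semantics described in the review.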


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
torch_sim/optimizers.py (1)

36-58: Rename to private _md_atom_attributes is applied; consistent with prior feedback

The top-level constant is now private and referenced consistently in state classes. Good follow-through on the previous suggestion.

🧹 Nitpick comments (2)
torch_sim/optimizers.py (2)

36-58: Make module-level attribute registries immutable (optional)

To avoid accidental mutation at runtime, consider using frozenset for these registries.

Apply this diff:

-_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"}  # noqa: SLF001
+_md_atom_attributes = frozenset(SimState._atom_attributes | {"forces", "velocities"})  # noqa: SLF001

-_fire_system_attributes = (
+_fire_system_attributes = frozenset(
     SimState._system_attributes  # noqa: SLF001
     | DeformGradMixin._system_attributes  # noqa: SLF001
     | {
         "energy",
         "stress",
         "cell_positions",
         "cell_velocities",
         "cell_forces",
         "cell_masses",
         "cell_factor",
         "pressure",
         "dt",
         "alpha",
         "n_pos",
     }
 )
-_fire_global_attributes = SimState._global_attributes | {  # noqa: SLF001
+_fire_global_attributes = frozenset(SimState._global_attributes | {  # noqa: SLF001
     "hydrostatic_strain",
     "constant_volume",
-}
+})

224-239: Nit: pass dtype to torch.eye in related functions to avoid implicit casts

In unit_cell_gradient_descent.gd_init, torch.eye(...) is created without dtype, which can cause dtype promotion when mixed with dtype tensors.

Suggested adjustments outside this hunk:

-        pressure = scalar_pressure * torch.eye(3, device=device)
+        pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)

-            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
+            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device, dtype=dtype).unsqueeze(

-            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
-                3, device=device
+            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
+                3, device=device, dtype=dtype
             ).unsqueeze(0).expand(state.n_systems, -1, -1)

📥 Commits

Reviewing files that changed from the base of the PR and between 5e181ed and 50fa99d.

📒 Files selected for processing (1)
  • torch_sim/optimizers.py (10 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
🧬 Code Graph Analysis (1)
torch_sim/optimizers.py (3)
torch_sim/state.py (4)
  • SimState (29-493)
  • DeformGradMixin (497-533)
  • SimState (26-401)
  • _get_property_attrs (599-632)
torch_sim/integrators/md.py (1)
  • MDState (14-49)
tests/test_state.py (1)
  • DeformState (497-509)
🔇 Additional comments (8)
torch_sim/optimizers.py (8)

82-84: GDState attribute scopes are correct

forces as per-atom and energy as per-system matches usage. Good.


224-239: UnitCellGDState scopes look consistent with DeformGradMixin and added fields

System attributes include reference_cell (via mixin), stress, pressure, and cell_*; globals include hydrostatic_strain and constant_volume. Looks good.


483-483: kw_only dataclass for FireState is appropriate

This is consistent with the PR’s scope-enforcement and constructor changes.


748-751: UnitCellFireState attribute scopes look correct

Atom attributes use md_atom_attributes; system/global sets match FIRE requirements incl. cell* and dt/alpha/n_pos.


966-966: kw_only dataclass for FrechetCellFIREState is appropriate

Matches the rest of the module and PR goals.


1040-1043: FrechetCellFIREState scopes look consistent with FIRE and deform-grad needs

Atom/system/global sets align with fire* registries. LGTM.


178-178: No positional instantiations found for UnitCellGDState
All existing call sites use keyword arguments, so the switch to kw_only=True does not break construction.

• torch_sim/optimizers.py:379–381:
return UnitCellGDState(positions=state.positions, forces=forces, …)


1040-1043: Minimum PyTorch requirement already covers torch.segment_reduce

The project’s pyproject.toml pins
• torch>=2.0 (line 33)
And torch.segment_reduce was introduced in v1.9, so this usage is safe. No guard or version bump needed.

@@ -1261,7 +1321,7 @@ def _vv_fire_step(  # noqa: C901, PLR0915
            state.positions[nan_velocities]
        )
    if is_cell_optimization:
-        if not isinstance(state, AnyFireCellState):
+        if not isinstance(state, get_args(AnyFireCellState)):
Collaborator Author

@curtischong curtischong Aug 11, 2025


without get_args this isinstance check is wrong, I believe: since AnyFireCellState is a union type, we need to convert it to a tuple of its individual classes before passing it to isinstance. coderabbit caught this, which is why I put the fix in this PR
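A standalone sketch of the issue: `isinstance()` rejects a `typing.Union` alias at runtime (unlike a PEP 604 `A | B` union on Python 3.10+, which it accepts), so the alias must first be unpacked into a plain tuple of classes with `typing.get_args`. The state classes below are stand-ins for torch-sim's real ones.

```python
from typing import Union, get_args


class UnitCellFireState: ...


class FrechetCellFIREState: ...


AnyFireCellState = Union[UnitCellFireState, FrechetCellFIREState]

state = UnitCellFireState()

# Passing the typing.Union alias straight to isinstance raises TypeError.
try:
    isinstance(state, AnyFireCellState)
except TypeError:
    print("union passed straight to isinstance: TypeError")

# get_args unpacks the union into a tuple of classes, which isinstance accepts.
assert get_args(AnyFireCellState) == (UnitCellFireState, FrechetCellFIREState)
assert isinstance(state, get_args(AnyFireCellState))
```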

@curtischong curtischong marked this pull request as draft August 11, 2025 03:54
@curtischong curtischong force-pushed the classify-range-of-simstate-feats2 branch from 3441ae2 to acabbad Compare August 11, 2025 04:03
@curtischong curtischong force-pushed the classify-range-of-simstate-feats2 branch from acabbad to 6f7439b Compare August 11, 2025 04:06
@curtischong curtischong marked this pull request as ready for review August 11, 2025 14:31

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
torch_sim/state.py (1)

643-692: Vectorize system_idx remapping to avoid Python loops

The current list-comprehension remapping for system_idx works but scales poorly. Vectorize to reduce Python overhead and allocations.

Apply within this block:

-        if attr_name == "system_idx":
-            # Get the old system indices for the selected atoms
-            old_system_idxs = attr_value[atom_mask]
-
-            # Get the system indices that are kept
-            kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[
-                system_mask
-            ]
-
-            # Create a mapping from old system indices to new consecutive indices
-            system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)}
-
-            # Create new system tensor with remapped indices
-            new_system_idxs = torch.tensor(
-                [system_idx_map[b.item()] for b in old_system_idxs],
-                device=attr_value.device,
-                dtype=attr_value.dtype,
-            )
-            filtered_attrs[attr_name] = new_system_idxs
+        if attr_name == "system_idx":
+            old_system_idxs = attr_value[atom_mask]
+            n_systems = state.n_systems
+            kept_indices = torch.arange(n_systems, device=attr_value.device)[system_mask]
+            mapping = torch.full((n_systems,), -1, device=attr_value.device, dtype=attr_value.dtype)
+            mapping[kept_indices] = torch.arange(
+                kept_indices.numel(), device=attr_value.device, dtype=attr_value.dtype
+            )
+            filtered_attrs[attr_name] = mapping[old_system_idxs]
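The mapping-tensor approach suggested above can be exercised in isolation: build a dense old-to-new index table once, then remap every kept atom's `system_idx` with a single gather instead of a Python dict comprehension. Tensor names mirror the diff; the data here is synthetic.

```python
import torch

system_idx = torch.tensor([0, 0, 1, 1, 2, 2, 3])  # per-atom system labels
system_mask = torch.tensor([True, False, True, True])  # keep systems 0, 2, 3
atom_mask = system_mask[system_idx]  # keep atoms whose system is kept

old_system_idxs = system_idx[atom_mask]
n_systems = 4
kept_indices = torch.arange(n_systems)[system_mask]

# Dense lookup table: -1 marks dropped systems, kept systems get 0..k-1.
mapping = torch.full((n_systems,), -1, dtype=system_idx.dtype)
mapping[kept_indices] = torch.arange(kept_indices.numel(), dtype=system_idx.dtype)

new_system_idxs = mapping[old_system_idxs]
print(new_system_idxs.tolist())  # [0, 0, 1, 1, 2]
```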
🧹 Nitpick comments (1)
tests/test_state.py (1)

27-39: Good coverage of per-atom/system/global scopes

Assertions match SimState’s declared scopes. Consider adding a negative test for an invalid scope to lock in the error behavior.

Example:

+def test_get_attrs_for_scope_invalid_scope(si_sim_state: ts.SimState) -> None:
+    with pytest.raises(ValueError, match="Unknown scope"):
+        list(get_attrs_for_scope(si_sim_state, "invalid"))  # type: ignore[arg-type]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 50fa99d and 81f830b.

📒 Files selected for processing (4)
  • examples/tutorials/state_tutorial.py (2 hunks)
  • tests/test_state.py (3 hunks)
  • torch_sim/optimizers.py (12 hunks)
  • torch_sim/state.py (15 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/tutorials/state_tutorial.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
📚 Learning: 2025-08-11T03:47:27.656Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#228
File: torch_sim/optimizers.py:675-675
Timestamp: 2025-08-11T03:47:27.656Z
Learning: In the torch-sim project, `DeformGradMixin` is a `dataclass(kw_only=True)` that declares `reference_cell: torch.Tensor` as a field. Classes that inherit from `DeformGradMixin` (like `UnitCellFireState`) automatically inherit the `reference_cell` field through Python's dataclass inheritance mechanism, so they don't need to redeclare it.

Applied to files:

  • torch_sim/optimizers.py
  • torch_sim/state.py
📚 Learning: 2025-08-08T02:04:31.517Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.

Applied to files:

  • torch_sim/optimizers.py
🧬 Code Graph Analysis (2)
torch_sim/optimizers.py (3)
torch_sim/state.py (2)
  • SimState (29-493)
  • DeformGradMixin (497-538)
torch_sim/integrators/md.py (1)
  • MDState (14-49)
tests/test_optimizers.py (1)
  • test_fire_fixed_cell_unit_cell_consistency (786-880)
tests/test_state.py (3)
torch_sim/state.py (3)
  • get_attrs_for_scope (617-640)
  • SimState (29-493)
  • DeformGradMixin (497-538)
tests/conftest.py (1)
  • si_sim_state (142-144)
torch_sim/integrators/md.py (1)
  • MDState (14-49)
🔇 Additional comments (25)
tests/test_state.py (4)

16-18: API migration to get_attrs_for_scope looks correct

Import aligns with the new public API; removes reliance on shape inference. No issues.


42-57: Scope-coverage enforcement test is solid

The test accurately validates that missing attributes trigger a TypeError and that attributes present in scope registries are not flagged. LGTM.


59-73: Duplicate-attribute enforcement test is correct

Catches attributes declared in multiple scopes and asserts the error message. Looks good.


514-517: DeformState: proper composition of system attributes

Unioning SimState and DeformGradMixin system attributes makes reference_cell available for scope logic without polluting atom/global scopes. Good.

torch_sim/state.py (9)

88-96: Explicit scope registries are well-chosen

Per-atom/system/global sets are minimal and unambiguous. This enables deterministic behavior in downstream utilities.


218-229: batch property deprecation path is consistent

Return type is non-optional and mirrors system_idx with deprecation warnings. Matches tests.


415-426: Subclass enforcement placement is correct

Running validations before `super().__init_subclass__()` ensures early failure for misdeclared subclasses.


427-448: None-forbidden tensor enforcement is robust

Covers both typing.Union and PEP 604 unions. Clear error message to authors.


449-494: Scope enforcement: duplicates and coverage are handled

  • Detects duplicates across scope sets.
  • Aggregates annotations across MRO.
  • Skips properties, methods, and ClassVar-annotated names to avoid false positives.
    This is the right balance between safety and flexibility.
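A hedged sketch (not torch-sim's actual implementation) of the pattern this comment describes: a base class's `__init_subclass__` hook checks that every annotated field across the MRO is assigned to exactly one scope set, skipping `ClassVar` annotations, so a misdeclared subclass fails at class-definition time. Field names and error messages are illustrative.

```python
from typing import ClassVar


class ScopedState:
    _atom_attributes: ClassVar[set[str]] = {"positions"}
    _system_attributes: ClassVar[set[str]] = {"cell"}
    _global_attributes: ClassVar[set[str]] = set()

    positions: list
    cell: list

    def __init_subclass__(cls, **kwargs) -> None:
        scopes = (cls._atom_attributes, cls._system_attributes, cls._global_attributes)
        # 1) no attribute may appear in more than one scope set
        dups = (scopes[0] & scopes[1]) | (scopes[0] & scopes[2]) | (scopes[1] & scopes[2])
        if dups:
            raise TypeError(f"attributes declared in multiple scopes: {dups}")
        # 2) every annotated field across the MRO (minus ClassVars) needs a scope
        annotated = {
            name
            for klass in cls.__mro__
            for name, ann in getattr(klass, "__annotations__", {}).items()
            if "ClassVar" not in str(ann)
        }
        missing = annotated - set().union(*scopes)
        if missing:
            raise TypeError(f"attributes without a declared scope: {missing}")
        super().__init_subclass__(**kwargs)


class GoodState(ScopedState):  # passes: 'forces' is registered as per-atom
    _atom_attributes = ScopedState._atom_attributes | {"forces"}
    forces: list


err = None
try:
    class BadState(ScopedState):  # fails: 'momenta' has no scope
        momenta: list
except TypeError as exc:
    err = exc

print(err)  # attributes without a declared scope: {'momenta'}
```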

496-503: DeformGradMixin: kw_only dataclass with scoped reference_cell

Declaring reference_cell and registering it in _system_attributes fits the new model; TYPE_CHECKING guard for row_vector_cell avoids runtime field injection.

Also applies to: 504-508


617-641: get_attrs_for_scope: clear API and defensive default

Literal-typed scope, explicit default raising ValueError, and generator semantics are clean. Matches tests.


713-717: Split state: correct handling of scopes

  • Ignores system_idx in per-atom splits and reconstructs it to zeros.
  • Splits per-system on dim=0 as expected.
  • Copies global attributes verbatim.
    All good.

Also applies to: 719-721, 722-723


783-792: Pop states: correct reuse of mask-based filtering

The keep/pop paths share the same filtering logic; then the popped aggregate is split into per-system states. Solid approach.

torch_sim/optimizers.py (12)

36-58: Well-scoped FIRE/MD attribute registries

  • _md_atom_attributes extends per-atom with forces/velocities.
  • _fire_system_attributes unions SimState, DeformGradMixin, and FIRE-specific fields.
  • _fire_global_attributes captures booleans controlling cell optimization.
    This centralization reduces drift across classes.

82-84: GDState: scope sets correctly reflect forces/energy

Matches the docstring and downstream usage. Good.


178-178: UnitCellGDState kw_only dataclass

kw_only aligns with SimState subclassing changes and reduces param-ordering errors.


224-239: UnitCellGDState: merged scope declarations are coherent

Includes DeformGradMixin.system attributes (reference_cell) and adds cell/stress-related per-system and global items. Looks consistent.


483-483: FireState kw_only dataclass

Consistent with the new constructor pattern.


527-537: FireState: scopes include energy and FIRE params per-system

Atom attributes reuse the MD set; system-level adds energy, dt, alpha, n_pos. Matches algorithm needs.


675-675: UnitCellFireState kw_only dataclass

Works with DeformGradMixin being kw_only; no need to redeclare reference_cell.


748-751: UnitCellFireState: attribute scopes align with FIRE cell optimization

Reuses shared FIRE sets; keeps global booleans separate. Good.


966-966: FrechetCellFIREState kw_only dataclass

Consistent with neighboring states and mixin expectations.


1039-1042: FrechetCellFIREState: scopes mirror UnitCellFireState

Keeps atom/system/global symmetry between the two FIRE variants. Good.


1323-1326: Fix isinstance checks using get_args(AnyFireCellState)

Using a tuple of concrete classes avoids runtime TypeError. Correct fix.


1536-1536: Initialize cur_deform_grad and guard is_cell_optimization

Prevents UnboundLocalError and properly validates state type for cell optimization. LGTM.

Also applies to: 1546-1554


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
torch_sim/state.py (1)

449-494: Consider adding defensive checks for scope attribute existence.

While the validation logic is well-implemented, child classes might not define all three scope attributes. Consider checking attribute existence before accessing them:

 @classmethod
 def _assert_all_attributes_have_defined_scope(cls) -> None:
+    # Ensure scope attributes exist with defaults
+    atom_attrs = getattr(cls, '_atom_attributes', set())
+    system_attrs = getattr(cls, '_system_attributes', set())
+    global_attrs = getattr(cls, '_global_attributes', set())
     all_defined_attributes = (
-        cls._atom_attributes | cls._system_attributes | cls._global_attributes
+        atom_attrs | system_attrs | global_attrs
     )
     # 1) assert that no attribute is defined twice in all_defined_attributes
     duplicates = (
-        (cls._atom_attributes & cls._system_attributes)
-        | (cls._atom_attributes & cls._global_attributes)
-        | (cls._system_attributes & cls._global_attributes)
+        (atom_attrs & system_attrs)
+        | (atom_attrs & global_attrs)
+        | (system_attrs & global_attrs)
     )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 81f830b and 7ba046b.

📒 Files selected for processing (1)
  • torch_sim/state.py (15 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
📚 Learning: 2025-08-11T03:47:27.656Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#228
File: torch_sim/optimizers.py:675-675
Timestamp: 2025-08-11T03:47:27.656Z
Learning: In the torch-sim project, `DeformGradMixin` is a `dataclass(kw_only=True)` that declares `reference_cell: torch.Tensor` as a field. Classes that inherit from `DeformGradMixin` (like `UnitCellFireState`) automatically inherit the `reference_cell` field through Python's dataclass inheritance mechanism, so they don't need to redeclare it.

Applied to files:

  • torch_sim/state.py
🧬 Code Graph Analysis (1)
torch_sim/state.py (1)
tests/test_state.py (2)
  • test_infer_sim_state_property_scope (27-37)
  • DeformState (497-509)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/state.py (8)

88-95: LGTM! Clear attribute scope definitions.

The use of sets for scope attributes is appropriate and provides clear documentation of each attribute's scope. This addresses the ambiguous shape-inference errors mentioned in the PR objectives.


218-229: Type annotation correctly updated.

The return type change from torch.Tensor | None to torch.Tensor aligns with the PR's enforcement that tensor attributes cannot be None, as system_idx is always initialized.


496-507: Correct implementation of DeformGradMixin scope.

The kw_only=True addresses parameter ordering issues, and _system_attributes correctly declares the scope of reference_cell. Moving row_vector_cell to TYPE_CHECKING block is appropriate since it's a SimState property.


617-641: Well-implemented scope accessor with proper error handling.

The function correctly replaces dynamic scope inference with explicit scope access. The default case properly handles invalid scopes as suggested in previous reviews.


643-693: Clean refactoring using explicit scopes.

The removal of the ambiguous_handling parameter and use of get_attrs_for_scope simplifies the function while maintaining correct system_idx remapping logic.


696-745: Correct state splitting with explicit scopes.

The function properly handles splitting by scope, correctly excludes system_idx from per-atom splitting, and regenerates it as zeros for each single-system state.


872-915: Efficient concatenation avoiding system_idx double-handling.

The implementation correctly addresses the previous review comment by skipping system_idx during per-atom collection and handling it separately with proper offset calculation.
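The offset handling described above can be shown with synthetic data: when concatenating per-system states, each incoming `system_idx` tensor is shifted by the number of systems already accumulated, so the combined labels stay consecutive. This is a sketch of the idea, not the module's actual concatenation code.

```python
import torch

states = [torch.tensor([0, 0, 1]), torch.tensor([0]), torch.tensor([0, 1, 1])]

parts, offset = [], 0
for idx in states:
    parts.append(idx + offset)
    offset += int(idx.max()) + 1  # systems consumed so far

combined = torch.cat(parts)
print(combined.tolist())  # [0, 0, 1, 2, 3, 4, 4]
```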


783-831: Consistent refactoring of state manipulation functions.

The updates to _pop_states and _slice_state correctly use the refactored _filter_attrs_by_mask function, maintaining consistency with the new explicit scope system.

@orionarcher
Collaborator

orionarcher commented Aug 12, 2025

Since this is breaking let's plan to merge this when it's done and then release a v0.3.0 along with the recent bug fix?

@curtischong
Collaborator Author

curtischong commented Aug 12, 2025

oh I am finished with it, I just wanted to make the tests pass. I was going to bump the version in this PR, but nvm, I already bumped it to 0.3.0 when we renamed batch to system

@CompRhys CompRhys requested a review from orionarcher August 12, 2025 23:53
@curtischong
Collaborator Author

just so you're aware: if you submit a PR review with "request changes", no other person can approve it and have it merged. So we need your approval to move forward with this!

@orionarcher
Collaborator

no other person can approve it and have it merged

Oop my bad I thought it could be overridden 😅

Approved now, LGTM and thank you!

@CompRhys CompRhys merged commit cde91e9 into main Aug 13, 2025
93 checks passed
@CompRhys CompRhys deleted the classify-range-of-simstate-feats2 branch August 13, 2025 20:16
@coderabbitai coderabbitai bot mentioned this pull request Aug 30, 2025
3 tasks
Labels
breaking Breaking changes cla-signed Contributor license agreement signed