Define attribute scopes in SimStates #228
Walkthrough
Adds explicit per-atom, per-system, and global attribute registries and enforces them at subclass creation; replaces dynamic scope inference with get_attrs_for_scope; many state dataclasses, integrators, optimizers, runners, examples, and tests updated to declare and use these scope sets and new atom-level attributes (e.g., last_permutation).
Sequence Diagram(s)
sequenceDiagram
participant Dev
participant StateClass
participant SimState
participant Consumer
Dev->>StateClass: Define dataclass fields + declare scope sets
StateClass->>SimState: class creation triggers __init_subclass__
SimState->>SimState: validate scopes & non-None tensor attributes
Consumer->>SimState: call get_attrs_for_scope(state, scope)
SimState-->>Consumer: yield (name, value) pairs for requested scope
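The flow in the diagram can be sketched in plain Python. This is a minimal illustration under stated assumptions, not torch-sim's implementation: `ToyState`, `ToyMDState`, `BadState`, and the list-valued fields are hypothetical stand-ins, and private names are skipped with a simple underscore check rather than a `ClassVar` check. Only the names `_atom_attributes`, `_system_attributes`, `_global_attributes`, and `get_attrs_for_scope` come from the PR.

```python
from collections.abc import Iterator
from typing import Any, ClassVar


class ToyState:
    """Hypothetical stand-in for SimState (plain lists instead of tensors)."""

    positions: list
    cell: list
    pbc: bool

    _atom_attributes: ClassVar[set[str]] = {"positions"}
    _system_attributes: ClassVar[set[str]] = {"cell"}
    _global_attributes: ClassVar[set[str]] = {"pbc"}

    def __init__(self, **fields: Any) -> None:
        for name, value in fields.items():
            setattr(self, name, value)

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Collect annotated names across the MRO (assumption: names starting
        # with an underscore are registries, not state attributes).
        declared: set[str] = set()
        for klass in cls.__mro__:
            declared |= set(getattr(klass, "__annotations__", {}))
        declared = {name for name in declared if not name.startswith("_")}
        scoped = cls._atom_attributes | cls._system_attributes | cls._global_attributes
        missing = declared - scoped
        if missing:
            raise TypeError(
                f"{cls.__name__}: attributes {sorted(missing)} are not declared in any scope"
            )


def get_attrs_for_scope(state: ToyState, scope: str) -> Iterator[tuple[str, Any]]:
    """Yield (name, value) pairs for the requested scope."""
    registries = {
        "per-atom": state._atom_attributes,
        "per-system": state._system_attributes,
        "global": state._global_attributes,
    }
    if scope not in registries:
        raise ValueError(f"Unknown scope: {scope!r}")
    for name in registries[scope]:
        yield name, getattr(state, name)


class ToyMDState(ToyState):
    velocities: list
    energy: list

    _atom_attributes = ToyState._atom_attributes | {"velocities"}
    _system_attributes = ToyState._system_attributes | {"energy"}


state = ToyMDState(
    positions=[[0.0, 0.0, 0.0]],
    velocities=[[0.1, 0.0, 0.0]],
    cell=[[1.0]],
    energy=[-1.0],
    pbc=True,
)
per_atom = dict(get_attrs_for_scope(state, "per-atom"))

# A subclass that forgets to register a field fails at class creation.
try:
    class BadState(ToyState):
        charge: list  # not registered in any scope
except TypeError as exc:
    error_message = str(exc)
else:
    error_message = ""
```

Because the check runs in `__init_subclass__`, a missing registration surfaces when the module is imported rather than later during a simulation.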
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
torch_sim/state.py
    def __init_subclass__(cls, **kwargs) -> None:
        """Enforce all of child classes's attributes are specified in _atom_attributes,
        _system_attributes, or _global_attributes.
        """
We should document the new requirements somewhere
The tests are failing because of the issue pointed out in this PR: #219. I'll revisit this tomorrow.
:'( I just got a |
add test for the new pre-defined attributes wip make fire simstate attributes predefined rename features to attributes define attribute scope for fire state consolidate attribute definitions and add it to npt rm infer_property_scope import __init_subclass__ to enforce that all attributes are specified find callable attributes do not check some user attributes filter for @properties __init_subclass__ doesn't catch for static state since it's a grandchild state added documentation for get_attrs_for_scope more examples for the other scopes cleaner documentation rename batch to system more docs cleanup test for running the init subclass more tests define more scope for the integrators split state handles none properties a bit better
Actionable comments posted: 1
🔭 Outside diff range comments (1)
torch_sim/state.py (1)
867-905: concatenate_states: skip collecting system_idx to avoid redundant work
System indices are recomputed via new_system_indices; also collecting/catting system_idx above is redundant and briefly incorrect before being overwritten.
-        for prop, val in get_attrs_for_scope(state, "per-atom"):
-            # if hasattr(state, prop):
-            per_atom_tensors[prop].append(val)
+        for prop, val in get_attrs_for_scope(state, "per-atom"):
+            if prop == "system_idx":
+                continue
+            per_atom_tensors[prop].append(val)
🧹 Nitpick comments (4)
torch_sim/optimizers.py (2)
333-378: UnitCell GD init: ensure dtype consistency for pressure and identity
torch.eye uses the global default dtype (float32 unless changed); if the model dtype is float64 this will trigger promotions. Make eye explicitly match dtype to avoid subtle perf/dtype issues.
-    pressure = scalar_pressure * torch.eye(3, device=device)
+    pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)

-    virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
+    virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device, dtype=dtype).unsqueeze(
         0
     ).expand(state.n_systems, -1, -1)

-    virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
-        3, device=device
+    virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
+        3, device=device, dtype=dtype
     ).unsqueeze(0).expand(state.n_systems, -1, -1)
483-536: FireState: docstring mentions momenta but no property implemented
Either implement the property or remove it from the docstring. Implementing it is tiny and useful.
 class FireState(SimState):
@@
     Properties:
         momenta (torch.Tensor): Atomwise momenta of the system with shape [n_atoms, 3],
             calculated as velocities * masses
@@
     n_pos: torch.Tensor
+
+    @property
+    def momenta(self) -> torch.Tensor:
+        """Atomwise momenta: velocities * masses."""
+        return self.velocities * self.masses.unsqueeze(-1)
torch_sim/state.py (2)
449-494: Scope enforcement: considers ClassVar/properties/methods; one small improvement
Implementation correctly avoids false positives (ClassVar, properties, methods). Consider tightening the error message by including the class name and guidance to add the attribute to exactly one scope.
-            raise TypeError(
-                f"Attribute '{attr_name}' is not defined in {cls.__name__} in any "
-                "of _atom_attributes, _system_attributes, or _global_attributes"
-            )
+            raise TypeError(
+                f"{cls.__name__}: attribute '{attr_name}' is not declared in any "
+                "scope. Add it to exactly one of: _atom_attributes, "
+                "_system_attributes, or _global_attributes."
+            )
612-636: get_attrs_for_scope: LGTM; optional determinism
Function is correct and defensive. For deterministic iteration (helps debugging, reproducible dumps), consider yielding in sorted order.
-    for attr_name in attr_names:
+    for attr_name in sorted(attr_names):
         yield attr_name, getattr(state, attr_name)
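The determinism point can be illustrated with a small standalone sketch. Iteration order over a set of strings is not guaranteed to be stable between interpreter runs (string hashing is randomized by default), whereas `sorted` always yields the same order. `Record` and `iter_attrs_sorted` are hypothetical names used only for this example.

```python
from collections.abc import Iterator
from typing import Any


class Record:
    """Hypothetical object with a set-based attribute registry."""

    _attr_names = {"positions", "masses", "atomic_numbers", "system_idx"}

    def __init__(self) -> None:
        self.positions = [0.0]
        self.masses = [1.0]
        self.atomic_numbers = [14]
        self.system_idx = [0]


def iter_attrs_sorted(obj: Record) -> Iterator[tuple[str, Any]]:
    """Yield (name, value) pairs in alphabetical order of the attribute name."""
    for name in sorted(obj._attr_names):
        yield name, getattr(obj, name)


names = [name for name, _ in iter_attrs_sorted(Record())]
```

The cost is one sort per call over a handful of names, which is negligible next to the tensor work these attributes feed into.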
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py (1 hunks)
- examples/tutorials/hybrid_swap_tutorial.py (2 hunks)
- tests/test_state.py (3 hunks)
- torch_sim/integrators/md.py (1 hunks)
- torch_sim/integrators/npt.py (2 hunks)
- torch_sim/integrators/nvt.py (1 hunks)
- torch_sim/monte_carlo.py (1 hunks)
- torch_sim/optimizers.py (10 hunks)
- torch_sim/runners.py (2 hunks)
- torch_sim/state.py (15 hunks)
🚧 Files skipped from review as they are similar to previous changes (8)
- torch_sim/integrators/nvt.py
- examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
- tests/test_state.py
- torch_sim/monte_carlo.py
- torch_sim/runners.py
- torch_sim/integrators/md.py
- examples/tutorials/hybrid_swap_tutorial.py
- torch_sim/integrators/npt.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
torch_sim/optimizers.py (1)
- torch_sim/state.py (2): SimState (29-493), DeformGradMixin (497-533)
torch_sim/state.py (2)
- torch_sim/properties/correlations.py (2): update (178-196), append (57-71)
- tests/test_correlations.py (1): split (39-42)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (45)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-core (macos-14, 3.11, highest)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (9)
torch_sim/optimizers.py (4)
82-84: GDState scopes: LGTM
Per-atom and per-system scope declarations for forces and energy are correct and minimal.
178-239: UnitCellGDState scopes: LGTM
Scopes cover reference_cell (via mixin), stress/pressure/cell_* tensors, and global flags. Matches SimState’s enforcement.
748-751: UnitCellFireState: scope sets look correct
Atom/system/global registries match the class’ declared fields (including DeformGradMixin.reference_cell).
1040-1043: FrechetCellFIREState: scope sets look correct
Atom/system/global registries align with fields and mixin. No duplicates across scopes.
torch_sim/state.py (5)
88-96: Default scope sets on SimState: LGTM
Sane defaults and clear separation: per-atom (positions/masses/atomic_numbers/system_idx), per-system (cell), global (pbc).
415-426: Subclass enforcement: good placement and coverage
Both “no Tensor | None” and “all attributes must be scoped” checks are in __init_subclass__. Solid guardrails.
638-687: Filtering by masks: remapping system_idx is correct
Handling of system_idx with remapping to consecutive indices is correct and robust when systems are removed.
708-722: Split state: sensible split across scopes; global attrs copied
Per-atom/system splits and zeroed system_idx per single-system outputs are correct.
778-787: Pop states: correct masking and split; good non-mutating semantics
The keep/pop flow and subsequent split are sound.
Actionable comments posted: 2
♻️ Duplicate comments (1)
torch_sim/optimizers.py (1)
36-58: Rename to private _md_atom_attributes is applied; consistent with prior feedback
The top-level constant is now private and referenced consistently in state classes. Good follow-through on the previous suggestion.
🧹 Nitpick comments (2)
torch_sim/optimizers.py (2)
36-58: Make module-level attribute registries immutable (optional)
To avoid accidental mutation at runtime, consider using frozenset for these registries.
Apply this diff:
-_md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"}  # noqa: SLF001
+_md_atom_attributes = frozenset(SimState._atom_attributes | {"forces", "velocities"})  # noqa: SLF001

-_fire_system_attributes = (
+_fire_system_attributes = frozenset(
     SimState._system_attributes  # noqa: SLF001
     | DeformGradMixin._system_attributes  # noqa: SLF001
     | {
         "energy",
         "stress",
         "cell_positions",
         "cell_velocities",
         "cell_forces",
         "cell_masses",
         "cell_factor",
         "pressure",
         "dt",
         "alpha",
         "n_pos",
     }
 )

-_fire_global_attributes = SimState._global_attributes | {  # noqa: SLF001
+_fire_global_attributes = frozenset(SimState._global_attributes | {  # noqa: SLF001
     "hydrostatic_strain",
     "constant_volume",
-}
+})
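A short standalone sketch of why the explicit `frozenset(...)` wrapper matters. For set operators, the result type follows the left operand, so a union that starts from a plain `set` stays mutable unless it is wrapped. The registry contents below are illustrative and `"charges"` is a hypothetical attribute name, not one from the PR.

```python
# Assumption: the base registry is a plain (mutable) set, as in the diff above.
base_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"}

# set | set returns a set, so wrap the union to freeze the registry.
md_atom_attributes = frozenset(base_atom_attributes | {"forces", "velocities"})

try:
    md_atom_attributes.add("charges")  # frozenset has no add() method
except AttributeError:
    mutation_blocked = True
else:
    mutation_blocked = False

# With a frozenset on the left, further unions stay frozen without wrapping.
extended_atom_attributes = md_atom_attributes | {"charges"}
```

Membership tests and the `|` and `&` operators used by the scope checks work the same way on `frozenset`, so consumers of the registries should not need changes.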
224-239: Nit: pass dtype to torch.eye in related functions to avoid implicit casts
In unit_cell_gradient_descent.gd_init, torch.eye(...) is created without dtype, which can cause dtype promotion when mixed with dtype tensors.
Suggested adjustments outside this hunk:
-    pressure = scalar_pressure * torch.eye(3, device=device)
+    pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)
-    virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
+    virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device, dtype=dtype).unsqueeze(
-    virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
-        3, device=device
+    virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
+        3, device=device, dtype=dtype
     ).unsqueeze(0).expand(state.n_systems, -1, -1)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- torch_sim/optimizers.py (10 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
🧬 Code Graph Analysis (1)
torch_sim/optimizers.py (3)
- torch_sim/state.py (4): SimState (29-493), DeformGradMixin (497-533), SimState (26-401), _get_property_attrs (599-632)
- torch_sim/integrators/md.py (1): MDState (14-49)
- tests/test_state.py (1): DeformState (497-509)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (45)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.11, highest, metatomic, tests/models/test_metatomic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/optimizers.py (8)
82-84: GDState attribute scopes are correct
forces as per-atom and energy as per-system matches usage. Good.
224-239: UnitCellGDState scopes look consistent with DeformGradMixin and added fields
System attributes include reference_cell (via mixin), stress, pressure, and cell_*; globals include hydrostatic_strain and constant_volume. Looks good.
483-483: kw_only dataclass for FireState is appropriate
This is consistent with the PR’s scope-enforcement and constructor changes.
748-751: UnitCellFireState attribute scopes look correct
Atom attributes use _md_atom_attributes; system/global sets match FIRE requirements incl. cell_* and dt/alpha/n_pos.
966-966: kw_only dataclass for FrechetCellFIREState is appropriate
Matches the rest of the module and PR goals.
1040-1043: FrechetCellFIREState scopes look consistent with FIRE and deform-grad needs
Atom/system/global sets align with _fire_* registries. LGTM.
178-178: No positional instantiations found for UnitCellGDState
All existing call sites use keyword arguments, so the switch to kw_only=True does not break construction.
• torch_sim/optimizers.py:379–381: return UnitCellGDState(positions=state.positions, forces=forces, …)
1040-1043: Minimum PyTorch requirement already covers torch.segment_reduce
The project’s pyproject.toml pins
• torch>=2.0 (line 33)
And torch.segment_reduce was introduced in v1.9, so this usage is safe. No guard or version bump needed.
@@ -1261,7 +1321,7 @@ def _vv_fire_step(  # noqa: C901, PLR0915
                 state.positions[nan_velocities]
             )
     if is_cell_optimization:
-        if not isinstance(state, AnyFireCellState):
+        if not isinstance(state, get_args(AnyFireCellState)):
Without get_args, this isinstance check is wrong, I believe: since AnyFireCellState is a union type, we need to convert it to a tuple of its member types. CodeRabbit caught this, which is why I put it in this PR.
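A minimal sketch of the pattern, assuming `AnyFireCellState` is a `typing.Union` alias of state classes (its actual definition is not shown in this thread). The classes below are hypothetical placeholders. Whether `isinstance` accepts a union alias directly depends on the Python version and on how the alias was built, while a tuple of classes is accepted everywhere.

```python
from typing import Union, get_args


# Hypothetical placeholder classes standing in for the real state types.
class UnitCellState:
    pass


class FrechetCellState:
    pass


class PlainState:
    pass


# Assumed shape of the alias: a union of the cell-optimizing state classes.
AnyCellState = Union[UnitCellState, FrechetCellState]

# get_args unpacks the union into a tuple of its member classes.
cell_state_types = get_args(AnyCellState)


def is_cell_state(state: object) -> bool:
    """Check membership against the tuple of concrete classes."""
    return isinstance(state, cell_state_types)


cell_check = is_cell_state(FrechetCellState())
plain_check = is_cell_state(PlainState())
```

Computing the tuple once at module level avoids calling `get_args` inside a hot loop such as an optimizer step.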
Actionable comments posted: 1
🔭 Outside diff range comments (1)
torch_sim/state.py (1)
643-692: Vectorize system_idx remapping to avoid Python loops
The current list-comprehension remapping for system_idx works but scales poorly. Vectorize to reduce Python overhead and allocations.
Apply within this block:
-            if attr_name == "system_idx":
-                # Get the old system indices for the selected atoms
-                old_system_idxs = attr_value[atom_mask]
-
-                # Get the system indices that are kept
-                kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[
-                    system_mask
-                ]
-
-                # Create a mapping from old system indices to new consecutive indices
-                system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)}
-
-                # Create new system tensor with remapped indices
-                new_system_idxs = torch.tensor(
-                    [system_idx_map[b.item()] for b in old_system_idxs],
-                    device=attr_value.device,
-                    dtype=attr_value.dtype,
-                )
-                filtered_attrs[attr_name] = new_system_idxs
+            if attr_name == "system_idx":
+                old_system_idxs = attr_value[atom_mask]
+                n_systems = state.n_systems
+                kept_indices = torch.arange(n_systems, device=attr_value.device)[system_mask]
+                mapping = torch.full((n_systems,), -1, device=attr_value.device, dtype=attr_value.dtype)
+                mapping[kept_indices] = torch.arange(
+                    kept_indices.numel(), device=attr_value.device, dtype=attr_value.dtype
+                )
+                filtered_attrs[attr_name] = mapping[old_system_idxs]
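The lookup-table idea behind the suggested diff can be illustrated without torch. This sketch uses plain lists, so it shows only the remapping logic (old system index to new consecutive index), not the vectorized tensor indexing. `remap_system_idx` is a hypothetical helper, and unlike the diff it derives the atom selection from the system mask instead of taking a separate `atom_mask`.

```python
def remap_system_idx(system_idx: list[int], system_mask: list[bool]) -> list[int]:
    """Drop atoms of removed systems and renumber kept systems as 0..k-1."""
    # mapping[old] is the new index of a kept system, or -1 if it was removed.
    mapping = [-1] * len(system_mask)
    next_idx = 0
    for old_idx, keep in enumerate(system_mask):
        if keep:
            mapping[old_idx] = next_idx
            next_idx += 1
    # One table lookup per atom replaces the per-atom dictionary lookup.
    return [mapping[idx] for idx in system_idx if mapping[idx] != -1]


# Three systems with 2, 1, and 3 atoms; the middle system is removed.
new_idx = remap_system_idx([0, 0, 1, 2, 2, 2], [True, False, True])
```

In the tensor version the final list comprehension becomes a single indexing operation, `mapping[old_system_idxs]`, which is where the speedup comes from.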
🧹 Nitpick comments (1)
tests/test_state.py (1)
27-39: Good coverage of per-atom/system/global scopes
Assertions match SimState’s declared scopes. Consider adding a negative test for an invalid scope to lock in the error behavior.
Example:
+def test_get_attrs_for_scope_invalid_scope(si_sim_state: ts.SimState) -> None:
+    with pytest.raises(ValueError, match="Unknown scope"):
+        list(get_attrs_for_scope(si_sim_state, "invalid"))  # type: ignore[arg-type]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- examples/tutorials/state_tutorial.py (2 hunks)
- tests/test_state.py (3 hunks)
- torch_sim/optimizers.py (12 hunks)
- torch_sim/state.py (15 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/tutorials/state_tutorial.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
📚 Learning: 2025-08-11T03:47:27.656Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#228
File: torch_sim/optimizers.py:675-675
Timestamp: 2025-08-11T03:47:27.656Z
Learning: In the torch-sim project, `DeformGradMixin` is a `dataclass(kw_only=True)` that declares `reference_cell: torch.Tensor` as a field. Classes that inherit from `DeformGradMixin` (like `UnitCellFireState`) automatically inherit the `reference_cell` field through Python's dataclass inheritance mechanism, so they don't need to redeclare it.
Applied to files:
torch_sim/optimizers.py
torch_sim/state.py
📚 Learning: 2025-08-08T02:04:31.517Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
Applied to files:
torch_sim/optimizers.py
🧬 Code Graph Analysis (2)
torch_sim/optimizers.py (3)
- torch_sim/state.py (2): SimState (29-493), DeformGradMixin (497-538)
- torch_sim/integrators/md.py (1): MDState (14-49)
- tests/test_optimizers.py (1): test_fire_fixed_cell_unit_cell_consistency (786-880)
tests/test_state.py (3)
- torch_sim/state.py (3): get_attrs_for_scope (617-640), SimState (29-493), DeformGradMixin (497-538)
- tests/conftest.py (1): si_sim_state (142-144)
- torch_sim/integrators/md.py (1): MDState (14-49)
🔇 Additional comments (25)
tests/test_state.py (4)
16-18: API migration to get_attrs_for_scope looks correct
Import aligns with the new public API; removes reliance on shape inference. No issues.
42-57: Scope-coverage enforcement test is solid
The test accurately validates that missing attributes trigger a TypeError and that attributes present in scope registries are not flagged. LGTM.
59-73: Duplicate-attribute enforcement test is correct
Catches attributes declared in multiple scopes and asserts the error message. Looks good.
514-517: DeformState: proper composition of system attributes
Unioning SimState and DeformGradMixin system attributes makes reference_cell available for scope logic without polluting atom/global scopes. Good.
torch_sim/state.py (9)
88-96: Explicit scope registries are well-chosen
Per-atom/system/global sets are minimal and unambiguous. This enables deterministic behavior in downstream utilities.
218-229: batch property deprecation path is consistent
Return type is non-optional and mirrors system_idx with deprecation warnings. Matches tests.
415-426: Subclass enforcement placement is correct
Running validations before super().__init_subclass__ ensures early failure for misdeclared subclasses.
427-448: None-forbidden tensor enforcement is robust
Covers both typing.Union and PEP 604 unions. Clear error message to authors.
449-494: Scope enforcement: duplicates and coverage are handled
- Detects duplicates across scope sets.
- Aggregates annotations across MRO.
- Skips properties, methods, and ClassVar-annotated names to avoid false positives.
This is the right balance between safety and flexibility.
496-503: DeformGradMixin: kw_only dataclass with scoped reference_cell
Declaring reference_cell and registering it in _system_attributes fits the new model; TYPE_CHECKING guard for row_vector_cell avoids runtime field injection.
Also applies to: 504-508
617-641: get_attrs_for_scope: clear API and defensive default
Literal-typed scope, explicit default raising ValueError, and generator semantics are clean. Matches tests.
713-717: Split state: correct handling of scopes
- Ignores system_idx in per-atom splits and reconstructs it to zeros.
- Splits per-system on dim=0 as expected.
- Copies global attributes verbatim.
All good.
Also applies to: 719-721, 722-723
783-792: Pop states: correct reuse of mask-based filtering
The keep/pop paths share the same filtering logic; then the popped aggregate is split into per-system states. Solid approach.
torch_sim/optimizers.py (12)
36-58: Well-scoped FIRE/MD attribute registries
- _md_atom_attributes extends per-atom with forces/velocities.
- _fire_system_attributes unions SimState, DeformGradMixin, and FIRE-specific fields.
- _fire_global_attributes captures booleans controlling cell optimization.
This centralization reduces drift across classes.
82-84: GDState: scope sets correctly reflect forces/energy
Matches the docstring and downstream usage. Good.
178-178: UnitCellGDState kw_only dataclass
kw_only aligns with SimState subclassing changes and reduces param-ordering errors.
224-239: UnitCellGDState: merged scope declarations are coherent
Includes DeformGradMixin.system attributes (reference_cell) and adds cell/stress-related per-system and global items. Looks consistent.
483-483: FireState kw_only dataclass
Consistent with the new constructor pattern.
527-537: FireState: scopes include energy and FIRE params per-system
Atom attributes reuse the MD set; system-level adds energy, dt, alpha, n_pos. Matches algorithm needs.
675-675: UnitCellFireState kw_only dataclass
Works with DeformGradMixin being kw_only; no need to redeclare reference_cell.
748-751: UnitCellFireState: attribute scopes align with FIRE cell optimization
Reuses shared FIRE sets; keeps global booleans separate. Good.
966-966: FrechetCellFIREState kw_only dataclass
Consistent with neighboring states and mixin expectations.
1039-1042: FrechetCellFIREState: scopes mirror UnitCellFireState
Keeps atom/system/global symmetry between the two FIRE variants. Good.
1323-1326: Fix isinstance checks using get_args(AnyFireCellState)
Using a tuple of concrete classes avoids runtime TypeError. Correct fix.
1536-1536: Initialize cur_deform_grad and guard is_cell_optimization
Prevents UnboundLocalError and properly validates state type for cell optimization. LGTM.
Also applies to: 1546-1554
Actionable comments posted: 0
🧹 Nitpick comments (1)
torch_sim/state.py (1)
449-494: Consider adding defensive checks for scope attribute existence.
While the validation logic is well-implemented, child classes might not define all three scope attributes. Consider checking attribute existence before accessing them:
     @classmethod
     def _assert_all_attributes_have_defined_scope(cls) -> None:
+        # Ensure scope attributes exist with defaults
+        atom_attrs = getattr(cls, '_atom_attributes', set())
+        system_attrs = getattr(cls, '_system_attributes', set())
+        global_attrs = getattr(cls, '_global_attributes', set())
         all_defined_attributes = (
-            cls._atom_attributes | cls._system_attributes | cls._global_attributes
+            atom_attrs | system_attrs | global_attrs
         )
         # 1) assert that no attribute is defined twice in all_defined_attributes
         duplicates = (
-            (cls._atom_attributes & cls._system_attributes)
-            | (cls._atom_attributes & cls._global_attributes)
-            | (cls._system_attributes & cls._global_attributes)
+            (atom_attrs & system_attrs)
+            | (atom_attrs & global_attrs)
+            | (system_attrs & global_attrs)
         )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/state.py (15 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
📚 Learning: 2025-08-11T03:47:27.656Z
Learnt from: curtischong
PR: Radical-AI/torch-sim#228
File: torch_sim/optimizers.py:675-675
Timestamp: 2025-08-11T03:47:27.656Z
Learning: In the torch-sim project, `DeformGradMixin` is a `dataclass(kw_only=True)` that declares `reference_cell: torch.Tensor` as a field. Classes that inherit from `DeformGradMixin` (like `UnitCellFireState`) automatically inherit the `reference_cell` field through Python's dataclass inheritance mechanism, so they don't need to redeclare it.
Applied to files:
torch_sim/state.py
🧬 Code Graph Analysis (1)
torch_sim/state.py (1)
tests/test_state.py (2)
test_infer_sim_state_property_scope (27-37)
DeformState (497-509)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/state.py (8)
88-95: LGTM! Clear attribute scope definitions.
The use of sets for scope attributes is appropriate and provides clear documentation of each attribute's scope. This addresses the ambiguous shape-inference errors mentioned in the PR objectives.
218-229: Type annotation correctly updated.
The return type change from `torch.Tensor | None` to `torch.Tensor` aligns with the PR's enforcement that tensor attributes cannot be None, as `system_idx` is always initialized.
496-507: Correct implementation of DeformGradMixin scope.
The `kw_only=True` addresses parameter ordering issues, and `_system_attributes` correctly declares the scope of `reference_cell`. Moving `row_vector_cell` to TYPE_CHECKING block is appropriate since it's a SimState property.
617-641: Well-implemented scope accessor with proper error handling.
The function correctly replaces dynamic scope inference with explicit scope access. The default case properly handles invalid scopes as suggested in previous reviews.
643-693: Clean refactoring using explicit scopes.
The removal of the ambiguous_handling parameter and use of `get_attrs_for_scope` simplifies the function while maintaining correct system_idx remapping logic.
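The system_idx remapping mentioned above can be illustrated with plain lists. This is a simplified sketch, assuming one per-atom attribute (`positions`) and one per-system attribute (`cells`); `filter_by_system_mask` is a hypothetical helper, not the library's `_filter_attrs_by_mask`.

```python
def filter_by_system_mask(
    positions: list, system_idx: list[int], cells: list, keep: list[bool]
) -> dict:
    """Keep only the systems where `keep` is True and renumber them from 0."""
    # Map each surviving old system index to its new consecutive index.
    new_index = {}
    for old, kept in enumerate(keep):
        if kept:
            new_index[old] = len(new_index)

    # Per-atom data: keep atoms whose system survives, remap their index.
    kept_positions = []
    kept_system_idx = []
    for pos, idx in zip(positions, system_idx):
        if idx in new_index:
            kept_positions.append(pos)
            kept_system_idx.append(new_index[idx])

    # Per-system data: keep entries selected directly by the mask.
    kept_cells = [cell for cell, kept in zip(cells, keep) if kept]

    return {
        "positions": kept_positions,
        "system_idx": kept_system_idx,
        "cells": kept_cells,
    }
```

Knowing each attribute's scope up front is what lets the per-atom and per-system branches be chosen without inspecting shapes.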
696-745: Correct state splitting with explicit scopes.
The function properly handles splitting by scope, correctly excludes system_idx from per-atom splitting, and regenerates it as zeros for each single-system state.
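The splitting behavior described in this comment can be sketched with plain lists. The helper `split_by_system` and the dictionary layout are hypothetical; the point is only that per-atom data is partitioned by system, per-system data is indexed, and `system_idx` is rebuilt as zeros rather than split.

```python
def split_by_system(positions: list, system_idx: list[int], cells: list) -> list[dict]:
    """Split a batched toy state into one single-system state per system."""
    single_states = []
    for system in range(len(cells)):
        # Per-atom data: collect the atoms that belong to this system.
        atoms = [pos for pos, idx in zip(positions, system_idx) if idx == system]
        single_states.append(
            {
                "positions": atoms,
                # Per-system data: take the entry for this system.
                "cell": cells[system],
                # system_idx is not split; every atom now belongs to system 0.
                "system_idx": [0] * len(atoms),
            }
        )
    return single_states
```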
872-915: Efficient concatenation avoiding system_idx double-handling.
The implementation correctly addresses the previous review comment by skipping `system_idx` during per-atom collection and handling it separately with proper offset calculation.
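The offset calculation for `system_idx` during concatenation can be sketched as follows. `concatenate_toy_states` and the dictionary layout are hypothetical and use lists in place of tensors; each state's `system_idx` is shifted by the number of systems already accumulated.

```python
def concatenate_toy_states(states: list[dict]) -> dict:
    """Concatenate toy states, offsetting system_idx so systems stay distinct."""
    positions: list = []
    system_idx: list[int] = []
    cells: list = []
    offset = 0  # number of systems accumulated so far
    for state in states:
        # Per-atom data is appended as-is (system_idx is handled separately).
        positions.extend(state["positions"])
        # Shift this state's system indices past the systems already seen.
        system_idx.extend(idx + offset for idx in state["system_idx"])
        # Per-system data is appended as-is.
        cells.extend(state["cells"])
        offset += len(state["cells"])
    return {"positions": positions, "system_idx": system_idx, "cells": cells}
```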
783-831: Consistent refactoring of state manipulation functions.
The updates to `_pop_states` and `_slice_state` correctly use the refactored `_filter_attrs_by_mask` function, maintaining consistency with the new explicit scope system.
Since this is breaking let's plan to merge this when it's done and then release a v0.3.0 along with the recent bug fix?
oh I am finished with it. I just wanted to make the tests pass.
just so you're aware if you submit a PR review and "request for changes" no other person can approve it and have it merged. So we need your approval to move forward with this!
Oop my bad I thought it could be overridden 😅 Approved now, LGTM and thank you!
Summary
This PR makes handling SimStates much simpler. Rather than guessing if an attribute is per-node/system/global (via `infer_property_scope`), SimState (and its child classes) explicitly define the scope of each attribute. #227 was a precursor to this PR.
By explicitly defining a property scope for each state, we avoid "n_atoms (1) and n_systems (1) are equal, which means shapes cannot be inferred unambiguously" errors, which can happen when systems have a single atom.
In general I think we should prevent these edge cases from occurring because they are uncommon but happen often enough.
I added validation logic inside the `SimState` class to ensure that each child class explicitly defines the property scope for all its attributes.
Note: this PR introduces breaking changes. In particular, it makes all of the classes that derive from SimState kw_only. This is because the `_system_attributes` of the parent class are defined for all of the attributes, e.g. `_system_attributes = {"reference_cell"}`, so all child classes that inherit from this put their arguments AFTER the parent's attributes. But since we've assigned a value to `_system_attributes`, all future attributes (i.e. the ones in the child class) must be kw_only.
I think this breaking change is fine because using keywords when defining complex objects like `FireState` eliminates the chance of a user mixing up different `torch.Tensor` constructor params.
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommend installing the pre-commit hooks running in CI locally to speed up the development process. Simply run `pip install pre-commit && pre-commit install` to install the hooks which will check your code before each commit.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests