Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5cfd552
use hardcoded attr names
curtischong Jul 27, 2025
4781bc4
fix bad rebase
curtischong Aug 3, 2025
c4f8ee0
add last_permutation to system attributes
curtischong Aug 6, 2025
5204370
cleanup init subclass
curtischong Aug 9, 2025
b403fa3
define scope for deformgradmixin attributes
curtischong Aug 9, 2025
a902700
remove duplicate definitions of reference_cell. also fix scope defini…
curtischong Aug 9, 2025
9e71a89
fix _fire_system_attributes declaration
curtischong Aug 9, 2025
8110f43
add back reference cell attr to be defined
curtischong Aug 9, 2025
f2be537
manually add row_vector_cell to each reference_cell
curtischong Aug 9, 2025
2012d74
use kwonly to hopefully fix default arg issues
curtischong Aug 9, 2025
ba31ea1
more kwargs
curtischong Aug 9, 2025
6768f4e
make more states kwonly
curtischong Aug 9, 2025
4b83294
try different params
curtischong Aug 9, 2025
dc860c3
make the parent class a dataclass so attributes just propagate down
curtischong Aug 9, 2025
099e38c
fix: row_vector_cell is just an alias to cell.mT not an attribute
CompRhys Aug 10, 2025
e9c9aae
make deform kwonly true
curtischong Aug 10, 2025
c23ae3a
splitting logic now just uses torch.split
curtischong Aug 10, 2025
aa29f81
revert to old method of split per atom and system (more readable)
curtischong Aug 10, 2025
af09df4
see if it works if I remove reference_cell: torch.Tensor
curtischong Aug 10, 2025
0e130d8
coderabbit comments
curtischong Aug 10, 2025
3d89fb3
rm xdist and remove dependency between integrators and optimizers
curtischong Aug 10, 2025
9cbcd87
swap from tuple to set
curtischong Aug 10, 2025
5e181ed
rm classvar annotation from all simstate
curtischong Aug 11, 2025
50fa99d
make md_atom_attributes private
curtischong Aug 11, 2025
aee26df
fix isinstance type check
curtischong Aug 11, 2025
971b7bd
fix docs
curtischong Aug 11, 2025
318e9cf
cleanup test type annotation
curtischong Aug 11, 2025
d8f67ff
add back row_vector_cell to deformgrad mixin
curtischong Aug 11, 2025
6f7439b
clone row_vector_cell
curtischong Aug 11, 2025
77bdc90
rm clone of row_vector_cell
curtischong Aug 11, 2025
81f830b
try to define row_vector_cell for the typechecker
curtischong Aug 11, 2025
7ba046b
skip processing system_idx as recommended by coderabbit
curtischong Aug 11, 2025
c908a72
Merge branch 'main' into classify-range-of-simstate-feats2
orionarcher Aug 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class HybridSwapMCState(MDState):
"""

last_permutation: torch.Tensor
_atom_attributes = (
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
)


nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT)
Expand Down
5 changes: 5 additions & 0 deletions examples/tutorials/hybrid_swap_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
"""

# %%
from typing import ClassVar
import torch
import torch_sim as ts
from mace.calculators.foundations_models import mace_mp
from torch_sim.integrators.md import MDState
from torch_sim.models.mace import MaceModel

# Initialize the mace model
Expand Down Expand Up @@ -104,6 +106,9 @@ class HybridSwapMCState(ts.integrators.MDState):
"""

last_permutation: torch.Tensor
_atom_attributes = (
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
)


# %% [markdown]
Expand Down
27 changes: 19 additions & 8 deletions examples/tutorials/state_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,28 @@
the base SimState. Names are singular.
* Global attributes have any other shape or type, just `pbc` here. Names are singular.

You can use the `infer_property_scope` function to analyze a state's properties. This
For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's
name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`:

_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"}
_system_attributes = {"cell"}
_global_attributes = {"pbc"}

You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This
is mostly used internally but can be useful for debugging.
"""

# %%
from torch_sim.state import infer_property_scope
from torch_sim.state import get_attrs_for_scope

scope = infer_property_scope(si_state)
print(scope)
# loop through each attribute:
for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"):
print(f"per-atom attribute: {attr_name}")
print(f"value: {attr_value}")

# or access the attributes via a dict:
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))

# %% [markdown]
"""
Expand Down Expand Up @@ -257,10 +269,9 @@
)

print("MDState properties:")
scope = infer_property_scope(md_state)
print("Global properties:", scope["global"])
print("Per-atom properties:", scope["per_atom"])
print("Per-system properties:", scope["per_system"])
print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom")))
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system")))
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))


# %% [markdown]
Expand Down
71 changes: 45 additions & 26 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_pop_states,
_slice_state,
concatenate_states,
infer_property_scope,
get_attrs_for_scope,
initialize_state,
)

Expand All @@ -24,38 +24,52 @@
from pymatgen.core import Structure


def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None:
"""Test inference of property scope."""
scope = infer_property_scope(si_sim_state)
assert set(scope["global"]) == {"pbc"}
assert set(scope["per_atom"]) == {
def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None:
"""Test getting attributes for a scope."""
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom"))
assert set(per_atom_attrs.keys()) == {
"positions",
"masses",
"atomic_numbers",
"system_idx",
}
assert set(scope["per_system"]) == {"cell"}
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
assert set(per_system_attrs.keys()) == {"cell"}
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
assert set(global_attrs.keys()) == {"pbc"}


def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
"""Test inference of property scope."""
state = MDState(
**asdict(si_sim_state),
momenta=torch.randn_like(si_sim_state.positions),
forces=torch.randn_like(si_sim_state.positions),
energy=torch.zeros((1,)),
)
scope = infer_property_scope(state)
assert set(scope["global"]) == {"pbc"}
assert set(scope["per_atom"]) == {
"positions",
"masses",
"atomic_numbers",
"system_idx",
"forces",
"momenta",
}
assert set(scope["per_system"]) == {"cell", "energy"}
def test_all_attributes_must_be_specified_in_scopes() -> None:
"""Test that an error is raised when we forget to specify the scope
for an attribute in a child SimState class."""
with pytest.raises(TypeError) as excinfo:

class ChildState(SimState):
attribute_specified_in_scopes: bool
attribute_not_specified_in_scopes: bool

_atom_attributes = (
SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001
)

assert "attribute_not_specified_in_scopes" in str(excinfo.value)
assert "attribute_specified_in_scopes" not in str(excinfo.value)


def test_no_duplicate_attributes_in_scopes() -> None:
"""Test that no attributes are specified in multiple scopes."""

# Capture the exception information using "as excinfo"
with pytest.raises(TypeError) as excinfo:

class ChildState(SimState):
duplicated_attribute: bool

_system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001
_global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001

assert "are declared multiple times" in str(excinfo.value)
assert "duplicated_attribute" in str(excinfo.value)


def test_slice_substate(
Expand Down Expand Up @@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None:
class DeformState(SimState, DeformGradMixin):
"""Test class that combines SimState with DeformGradMixin."""

_system_attributes = (
SimState._system_attributes # noqa: SLF001
| DeformGradMixin._system_attributes # noqa: SLF001
)

def __init__(
self,
*args,
Expand Down
7 changes: 7 additions & 0 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ class MDState(SimState):
energy: torch.Tensor
forces: torch.Tensor

_atom_attributes = (
SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001
)
_system_attributes = (
SimState._system_attributes | {"energy"} # noqa: SLF001
)

@property
def velocities(self) -> torch.Tensor:
"""Velocities calculated from momenta and masses with shape
Expand Down
31 changes: 31 additions & 0 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ class NPTLangevinState(SimState):
cell_velocities: torch.Tensor
cell_masses: torch.Tensor

_atom_attributes = (
SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001
)
_system_attributes = SimState._system_attributes | { # noqa: SLF001
"stress",
"cell_positions",
"cell_velocities",
"cell_masses",
"reference_cell",
"energy",
}

@property
def momenta(self) -> torch.Tensor:
"""Calculate momenta from velocities and masses."""
Expand Down Expand Up @@ -867,6 +879,25 @@ class NPTNoseHooverState(MDState):
barostat: NoseHooverChain
barostat_fns: NoseHooverChainFns

_system_attributes = (
MDState._system_attributes # noqa: SLF001
| {
"reference_cell",
"cell_position",
"cell_momentum",
"cell_mass",
}
)
_global_attributes = (
MDState._global_attributes # noqa: SLF001
| {
"thermostat",
"barostat",
"thermostat_fns",
"barostat_fns",
}
)

@property
def velocities(self) -> torch.Tensor:
"""Calculate particle velocities from momenta and masses.
Expand Down
4 changes: 4 additions & 0 deletions torch_sim/integrators/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ class NVTNoseHooverState(MDState):
chain: NoseHooverChain
_chain_fns: NoseHooverChainFns

_global_attributes = (
MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001
)

@property
def velocities(self) -> torch.Tensor:
"""Velocities calculated from momenta and masses with shape
Expand Down
3 changes: 3 additions & 0 deletions torch_sim/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SwapMCState(SimState):
energy: torch.Tensor
last_permutation: torch.Tensor

_atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001
_system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001


def generate_swaps(
state: SimState, generator: torch.Generator | None = None
Expand Down
Loading