Skip to content

Commit cde91e9

Browse files
curtischongCompRhysorionarcher
authored
Define attribute scopes in SimStates (#228)
Co-authored-by: Rhys Goodall <[email protected]> Co-authored-by: Orion Cohen <[email protected]>
1 parent a563523 commit cde91e9

File tree

11 files changed

+319
-215
lines changed

11 files changed

+319
-215
lines changed

examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class HybridSwapMCState(MDState):
7676
"""
7777

7878
last_permutation: torch.Tensor
79+
_atom_attributes = (
80+
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
81+
)
7982

8083

8184
nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT)

examples/tutorials/hybrid_swap_tutorial.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
"""
3535

3636
# %%
37+
from typing import ClassVar
3738
import torch
3839
import torch_sim as ts
3940
from mace.calculators.foundations_models import mace_mp
41+
from torch_sim.integrators.md import MDState
4042
from torch_sim.models.mace import MaceModel
4143

4244
# Initialize the mace model
@@ -104,6 +106,9 @@ class HybridSwapMCState(ts.integrators.MDState):
104106
"""
105107

106108
last_permutation: torch.Tensor
109+
_atom_attributes = (
110+
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
111+
)
107112

108113

109114
# %% [markdown]

examples/tutorials/state_tutorial.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,28 @@
7979
the base SimState. Names are singular.
8080
* Global attributes have any other shape or type, just `pbc` here. Names are singular.
8181
82-
You can use the `infer_property_scope` function to analyze a state's properties. This
82+
For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's
83+
name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`:
84+
85+
_atom_attributes = {"positions", "masses", "atomic_numbers", "system_idx"}
86+
_system_attributes = {"cell"}
87+
_global_attributes = {"pbc"}
88+
89+
You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This
8390
is mostly used internally but can be useful for debugging.
8491
"""
8592

8693
# %%
87-
from torch_sim.state import infer_property_scope
94+
from torch_sim.state import get_attrs_for_scope
8895

89-
scope = infer_property_scope(si_state)
90-
print(scope)
96+
# loop through each attribute:
97+
for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"):
98+
print(f"per-atom attribute: {attr_name}")
99+
print(f"value: {attr_value}")
91100

101+
# or access the attributes via a dict:
102+
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501
103+
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
92104

93105
# %% [markdown]
94106
"""
@@ -257,10 +269,9 @@
257269
)
258270

259271
print("MDState properties:")
260-
scope = infer_property_scope(md_state)
261-
print("Global properties:", scope["global"])
262-
print("Per-atom properties:", scope["per_atom"])
263-
print("Per-system properties:", scope["per_system"])
272+
print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom")))
273+
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system")))
274+
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
264275

265276

266277
# %% [markdown]

tests/test_state.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_pop_states,
1414
_slice_state,
1515
concatenate_states,
16-
infer_property_scope,
16+
get_attrs_for_scope,
1717
initialize_state,
1818
)
1919

@@ -24,38 +24,52 @@
2424
from pymatgen.core import Structure
2525

2626

27-
def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None:
28-
"""Test inference of property scope."""
29-
scope = infer_property_scope(si_sim_state)
30-
assert set(scope["global"]) == {"pbc"}
31-
assert set(scope["per_atom"]) == {
27+
def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None:
28+
"""Test getting attributes for a scope."""
29+
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom"))
30+
assert set(per_atom_attrs.keys()) == {
3231
"positions",
3332
"masses",
3433
"atomic_numbers",
3534
"system_idx",
3635
}
37-
assert set(scope["per_system"]) == {"cell"}
36+
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
37+
assert set(per_system_attrs.keys()) == {"cell"}
38+
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
39+
assert set(global_attrs.keys()) == {"pbc"}
3840

3941

40-
def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
41-
"""Test inference of property scope."""
42-
state = MDState(
43-
**asdict(si_sim_state),
44-
momenta=torch.randn_like(si_sim_state.positions),
45-
forces=torch.randn_like(si_sim_state.positions),
46-
energy=torch.zeros((1,)),
47-
)
48-
scope = infer_property_scope(state)
49-
assert set(scope["global"]) == {"pbc"}
50-
assert set(scope["per_atom"]) == {
51-
"positions",
52-
"masses",
53-
"atomic_numbers",
54-
"system_idx",
55-
"forces",
56-
"momenta",
57-
}
58-
assert set(scope["per_system"]) == {"cell", "energy"}
42+
def test_all_attributes_must_be_specified_in_scopes() -> None:
43+
"""Test that an error is raised when we forget to specify the scope
44+
for an attribute in a child SimState class."""
45+
with pytest.raises(TypeError) as excinfo:
46+
47+
class ChildState(SimState):
48+
attribute_specified_in_scopes: bool
49+
attribute_not_specified_in_scopes: bool
50+
51+
_atom_attributes = (
52+
SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001
53+
)
54+
55+
assert "attribute_not_specified_in_scopes" in str(excinfo.value)
56+
assert "attribute_specified_in_scopes" not in str(excinfo.value)
57+
58+
59+
def test_no_duplicate_attributes_in_scopes() -> None:
60+
"""Test that no attributes are specified in multiple scopes."""
61+
62+
# Capture the exception information using "as excinfo"
63+
with pytest.raises(TypeError) as excinfo:
64+
65+
class ChildState(SimState):
66+
duplicated_attribute: bool
67+
68+
_system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001
69+
_global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001
70+
71+
assert "are declared multiple times" in str(excinfo.value)
72+
assert "duplicated_attribute" in str(excinfo.value)
5973

6074

6175
def test_slice_substate(
@@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None:
497511
class DeformState(SimState, DeformGradMixin):
498512
"""Test class that combines SimState with DeformGradMixin."""
499513

514+
_system_attributes = (
515+
SimState._system_attributes # noqa: SLF001
516+
| DeformGradMixin._system_attributes # noqa: SLF001
517+
)
518+
500519
def __init__(
501520
self,
502521
*args,

torch_sim/integrators/md.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ class MDState(SimState):
4141
energy: torch.Tensor
4242
forces: torch.Tensor
4343

44+
_atom_attributes = (
45+
SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001
46+
)
47+
_system_attributes = (
48+
SimState._system_attributes | {"energy"} # noqa: SLF001
49+
)
50+
4451
@property
4552
def velocities(self) -> torch.Tensor:
4653
"""Velocities calculated from momenta and masses with shape

torch_sim/integrators/npt.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ class NPTLangevinState(SimState):
6767
cell_velocities: torch.Tensor
6868
cell_masses: torch.Tensor
6969

70+
_atom_attributes = (
71+
SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001
72+
)
73+
_system_attributes = SimState._system_attributes | { # noqa: SLF001
74+
"stress",
75+
"cell_positions",
76+
"cell_velocities",
77+
"cell_masses",
78+
"reference_cell",
79+
"energy",
80+
}
81+
7082
@property
7183
def momenta(self) -> torch.Tensor:
7284
"""Calculate momenta from velocities and masses."""
@@ -867,6 +879,25 @@ class NPTNoseHooverState(MDState):
867879
barostat: NoseHooverChain
868880
barostat_fns: NoseHooverChainFns
869881

882+
_system_attributes = (
883+
MDState._system_attributes # noqa: SLF001
884+
| {
885+
"reference_cell",
886+
"cell_position",
887+
"cell_momentum",
888+
"cell_mass",
889+
}
890+
)
891+
_global_attributes = (
892+
MDState._global_attributes # noqa: SLF001
893+
| {
894+
"thermostat",
895+
"barostat",
896+
"thermostat_fns",
897+
"barostat_fns",
898+
}
899+
)
900+
870901
@property
871902
def velocities(self) -> torch.Tensor:
872903
"""Calculate particle velocities from momenta and masses.

torch_sim/integrators/nvt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ class NVTNoseHooverState(MDState):
266266
chain: NoseHooverChain
267267
_chain_fns: NoseHooverChainFns
268268

269+
_global_attributes = (
270+
MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001
271+
)
272+
269273
@property
270274
def velocities(self) -> torch.Tensor:
271275
"""Velocities calculated from momenta and masses with shape

torch_sim/monte_carlo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class SwapMCState(SimState):
3636
energy: torch.Tensor
3737
last_permutation: torch.Tensor
3838

39+
_atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001
40+
_system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001
41+
3942

4043
def generate_swaps(
4144
state: SimState, generator: torch.Generator | None = None

0 commit comments

Comments
 (0)