Skip to content

Commit 0157533

Browse files
committed
use hardcoded attr names
add test for the new pre-defined attributes wip make fire simstate attributes predefined rename features to attributes define attribute scope for fire state consolidate attribute definitions and add it to npt rm infer_property_scope import __init_subclass__ to enforce that all attributes are specified find callable attributes do not check some user attributes filter for @properties __init_subclass__ doesn't catch for static state since it's a grandchild state added documentation for get_attrs_for_scope more examples for the other scopes cleaner documentation rename batch to system more docs cleanup test for running the init subclass more tests define more scope for the integrators split state handles none properties a bit better
1 parent 4c49f21 commit 0157533

File tree

11 files changed

+285
-215
lines changed

11 files changed

+285
-215
lines changed

examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class HybridSwapMCState(MDState):
7676
"""
7777

7878
last_permutation: torch.Tensor
79+
_system_attributes = (*MDState._system_attributes, "last_permutation") # noqa: SLF001
7980

8081

8182
nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT)

examples/tutorials/state_tutorial.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,24 +71,36 @@
7171

7272
# %% [markdown]
7373
"""
74-
SimState attributes fall into three categories: atomwise, batchwise, and global.
74+
SimState attributes fall into three categories: atomwise, systemwise, and global.
7575
7676
* Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`,
77-
`masses`, `atomic_numbers`, and `batch`. Names are plural.
78-
* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
77+
`masses`, `atomic_numbers`, and `system_idx`. Names are plural.
78+
* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
7979
the base SimState. Names are singular.
8080
* Global attributes have any other shape or type, just `pbc` here. Names are singular.
8181
82-
You can use the `infer_property_scope` function to analyze a state's properties. This
82+
For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute's
83+
name is explicitly defined in the `_atom_attributes`, `_system_attributes`, and `_global_attributes`:
84+
85+
_atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx")
86+
_system_attributes = ("cell",)
87+
_global_attributes = ("pbc",)
88+
89+
You can use the `get_attrs_for_scope` generator function to analyze a state's properties. This
8390
is mostly used internally but can be useful for debugging.
8491
"""
8592

8693
# %%
87-
from torch_sim.state import infer_property_scope
94+
from torch_sim.state import get_attrs_for_scope
8895

89-
scope = infer_property_scope(si_state)
90-
print(scope)
96+
# loop through each attribute:
97+
for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"):
98+
print(f"per-atom attribute: {attr_name}")
99+
print(f"value: {attr_value}")
91100

101+
# or access the attributes via a dict:
102+
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501
103+
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
92104

93105
# %% [markdown]
94106
"""
@@ -112,7 +124,7 @@
112124
f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems"
113125
)
114126

115-
# we can see how the shapes of batchwise, atomwise, and global properties change
127+
# we can see how the shapes of atomwise, systemwise, and global properties change
116128
print(f"Positions shape: {multi_state.positions.shape}")
117129
print(f"Cell shape: {multi_state.cell.shape}")
118130
print(f"PBC: {multi_state.pbc}")
@@ -142,7 +154,7 @@
142154
143155
SimState supports many convenience operations for manipulating batched states. Slicing
144156
is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state
145-
containing only the first three batches. The other operations are available through the
157+
containing only the first three systems. The other operations are available through the
146158
`pop`, `split`, `clone`, and `to` methods.
147159
"""
148160

@@ -182,19 +194,19 @@
182194
# %% [markdown]
183195
"""
184196
185-
You can extract specific batches from a batched state using Python's slicing syntax.
197+
You can extract specific systems from a batched state using Python's slicing syntax.
186198
This is extremely useful for analyzing specific systems or for implementing complex
187199
workflows where different systems need separate processing:
188200
189201
The slicing interface follows Python's standard indexing conventions, making it
190202
intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the
191-
selected batches, maintaining all the necessary properties and relationships.
203+
selected systems, maintaining all the necessary properties and relationships.
192204
193205
Note the difference between these operations:
194-
- `split()` returns all batches as separate states but doesn't modify the original
195-
- `pop()` removes specified batches from the original state and returns them as
206+
- `split()` returns all systems as separate states but doesn't modify the original
207+
- `pop()` removes specified systems from the original state and returns them as
196208
separate states
197-
- `__getitem__` (slicing) creates a new state with specified batches without modifying
209+
- `__getitem__` (slicing) creates a new state with specified systems without modifying
198210
the original
199211
200212
This flexibility allows you to structure your simulation workflows in the most
@@ -203,7 +215,7 @@
203215
### Splitting and Popping Batches
204216
205217
SimState provides methods to split a batched state into separate states or to remove
206-
specific batches:
218+
specific systems:
207219
"""
208220

209221
# %% [markdown]
@@ -257,10 +269,9 @@
257269
)
258270

259271
print("MDState properties:")
260-
scope = infer_property_scope(md_state)
261-
print("Global properties:", scope["global"])
262-
print("Per-atom properties:", scope["per_atom"])
263-
print("Per-system properties:", scope["per_system"])
272+
print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom")))
273+
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system")))
274+
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
264275

265276

266277
# %% [markdown]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ test = [
4444
"pymatgen>=2024.11.3",
4545
"pytest-cov>=6",
4646
"pytest>=8",
47+
"pytest-xdist>=3.8.0",
4748
]
4849
io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"]
4950
mace = ["mace-torch>=0.3.12"]

tests/test_state.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_pop_states,
1414
_slice_state,
1515
concatenate_states,
16-
infer_property_scope,
16+
get_attrs_for_scope,
1717
initialize_state,
1818
)
1919

@@ -24,38 +24,59 @@
2424
from pymatgen.core import Structure
2525

2626

27-
def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None:
28-
"""Test inference of property scope."""
29-
scope = infer_property_scope(si_sim_state)
30-
assert set(scope["global"]) == {"pbc"}
31-
assert set(scope["per_atom"]) == {
27+
def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None:
28+
"""Test getting attributes for a scope."""
29+
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom"))
30+
assert set(per_atom_attrs.keys()) == {
3231
"positions",
3332
"masses",
3433
"atomic_numbers",
3534
"system_idx",
3635
}
37-
assert set(scope["per_system"]) == {"cell"}
36+
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
37+
assert set(per_system_attrs.keys()) == {"cell"}
38+
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
39+
assert set(global_attrs.keys()) == {"pbc"}
3840

3941

40-
def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
41-
"""Test inference of property scope."""
42-
state = MDState(
43-
**asdict(si_sim_state),
44-
momenta=torch.randn_like(si_sim_state.positions),
45-
forces=torch.randn_like(si_sim_state.positions),
46-
energy=torch.zeros((1,)),
47-
)
48-
scope = infer_property_scope(state)
49-
assert set(scope["global"]) == {"pbc"}
50-
assert set(scope["per_atom"]) == {
51-
"positions",
52-
"masses",
53-
"atomic_numbers",
54-
"system_idx",
55-
"forces",
56-
"momenta",
57-
}
58-
assert set(scope["per_system"]) == {"cell", "energy"}
42+
def test_all_attributes_must_be_specified_in_scopes() -> None:
43+
"""Test that an error is raised when we forget to specify the scope
44+
for an attribute in a child SimState class."""
45+
with pytest.raises(TypeError) as excinfo:
46+
47+
class ChildState(SimState):
48+
attribute_specified_in_scopes: bool
49+
attribute_not_specified_in_scopes: bool
50+
51+
_atom_attributes = (
52+
*SimState._atom_attributes, # noqa: SLF001
53+
"attribute_specified_in_scopes",
54+
)
55+
56+
assert "attribute_not_specified_in_scopes" in str(excinfo.value)
57+
assert "attribute_specified_in_scopes" not in str(excinfo.value)
58+
59+
60+
def test_no_duplicate_attributes_in_scopes() -> None:
61+
"""Test that no attributes are specified in multiple scopes."""
62+
63+
# Capture the exception information using "as excinfo"
64+
with pytest.raises(TypeError) as excinfo:
65+
66+
class ChildState(SimState):
67+
duplicated_attribute: bool
68+
69+
_system_attributes = (
70+
*SimState._atom_attributes, # noqa: SLF001
71+
"duplicated_attribute",
72+
)
73+
_global_attributes = (
74+
*SimState._global_attributes, # noqa: SLF001
75+
"duplicated_attribute",
76+
)
77+
78+
assert "are declared multiple times" in str(excinfo.value)
79+
assert "duplicated_attribute" in str(excinfo.value)
5980

6081

6182
def test_slice_substate(

torch_sim/integrators/md.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class MDState(SimState):
4040
energy: torch.Tensor
4141
forces: torch.Tensor
4242

43+
_atom_attributes = (*SimState._atom_attributes, "momenta", "forces") # noqa: SLF001
44+
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
45+
4346
@property
4447
def velocities(self) -> torch.Tensor:
4548
"""Velocities calculated from momenta and masses with shape

torch_sim/integrators/npt.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
calculate_momenta,
1515
construct_nose_hoover_chain,
1616
)
17+
from torch_sim.optimizers import md_atom_attributes
1718
from torch_sim.quantities import calc_kinetic_energy
1819
from torch_sim.state import SimState
1920
from torch_sim.typing import StateDict
@@ -66,6 +67,17 @@ class NPTLangevinState(SimState):
6667
cell_velocities: torch.Tensor
6768
cell_masses: torch.Tensor
6869

70+
_atom_attributes = md_atom_attributes
71+
_system_attributes = (
72+
*SimState._system_attributes, # noqa: SLF001
73+
"stress",
74+
"cell_positions",
75+
"cell_velocities",
76+
"cell_masses",
77+
"reference_cell",
78+
"energy",
79+
)
80+
6981
@property
7082
def momenta(self) -> torch.Tensor:
7183
"""Calculate momenta from velocities and masses."""
@@ -866,6 +878,21 @@ class NPTNoseHooverState(MDState):
866878
barostat: NoseHooverChain
867879
barostat_fns: NoseHooverChainFns
868880

881+
_system_attributes = (
882+
*MDState._system_attributes, # noqa: SLF001
883+
"reference_cell",
884+
"cell_position",
885+
"cell_momentum",
886+
"cell_mass",
887+
)
888+
_global_attributes = (
889+
*MDState._global_attributes, # noqa: SLF001
890+
"thermostat",
891+
"barostat",
892+
"thermostat_fns",
893+
"barostat_fns",
894+
)
895+
869896
@property
870897
def velocities(self) -> torch.Tensor:
871898
"""Calculate particle velocities from momenta and masses.

torch_sim/integrators/nvt.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ class NVTNoseHooverState(MDState):
265265
chain: NoseHooverChain
266266
_chain_fns: NoseHooverChainFns
267267

268+
_global_attributes = (
269+
*MDState._global_attributes, # noqa: SLF001
270+
"chain",
271+
"_chain_fns",
272+
)
273+
268274
@property
269275
def velocities(self) -> torch.Tensor:
270276
"""Velocities calculated from momenta and masses with shape

torch_sim/monte_carlo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class SwapMCState(SimState):
3535
energy: torch.Tensor
3636
last_permutation: torch.Tensor
3737

38+
_atom_attributes = (*SimState._atom_attributes, "last_permutation") # noqa: SLF001
39+
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
40+
3841

3942
def generate_swaps(
4043
state: SimState, generator: torch.Generator | None = None

torch_sim/optimizers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@
3232
MdFlavor = Literal["vv_fire", "ase_fire"]
3333
vv_fire_key, ase_fire_key = get_args(MdFlavor)
3434

35+
md_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001
36+
_fire_system_attributes = (
37+
*SimState._system_attributes, # noqa: SLF001
38+
"energy",
39+
"stress",
40+
"cell_positions",
41+
"cell_velocities",
42+
"cell_forces",
43+
"cell_masses",
44+
"reference_cell",
45+
"cell_factor",
46+
"pressure",
47+
"dt",
48+
"alpha",
49+
"n_pos",
50+
)
51+
_fire_global_attributes = (
52+
*SimState._global_attributes, # noqa: SLF001
53+
"hydrostatic_strain",
54+
"constant_volume",
55+
)
56+
3557

3658
@dataclass
3759
class GDState(SimState):
@@ -55,6 +77,9 @@ class GDState(SimState):
5577
forces: torch.Tensor
5678
energy: torch.Tensor
5779

80+
_atom_attributes = (*SimState._atom_attributes, "forces") # noqa: SLF001
81+
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
82+
5883

5984
def gradient_descent(
6085
model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01
@@ -194,6 +219,22 @@ class UnitCellGDState(GDState, DeformGradMixin):
194219
cell_forces: torch.Tensor
195220
cell_masses: torch.Tensor
196221

222+
_system_attributes = (
223+
*GDState._system_attributes, # noqa: SLF001
224+
"cell_forces",
225+
"pressure",
226+
"stress",
227+
"cell_positions",
228+
"cell_factor",
229+
"cell_masses",
230+
)
231+
_global_attributes = (
232+
*GDState._global_attributes, # noqa: SLF001
233+
"reference_cell",
234+
"hydrostatic_strain",
235+
"constant_volume",
236+
)
237+
197238

198239
def unit_cell_gradient_descent( # noqa: PLR0915, C901
199240
model: torch.nn.Module,
@@ -481,6 +522,9 @@ class FireState(SimState):
481522
alpha: torch.Tensor
482523
n_pos: torch.Tensor
483524

525+
_atom_attributes = md_atom_attributes
526+
_system_attributes = (*SimState._system_attributes, "energy", "dt", "alpha", "n_pos") # noqa: SLF001
527+
484528

485529
def fire(
486530
model: torch.nn.Module,
@@ -692,6 +736,10 @@ class UnitCellFireState(SimState, DeformGradMixin):
692736
alpha: torch.Tensor
693737
n_pos: torch.Tensor
694738

739+
_atom_attributes = md_atom_attributes
740+
_system_attributes = _fire_system_attributes
741+
_global_attributes = _fire_global_attributes
742+
695743

696744
def unit_cell_fire(
697745
model: torch.nn.Module,
@@ -980,6 +1028,10 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
9801028
alpha: torch.Tensor
9811029
n_pos: torch.Tensor
9821030

1031+
_atom_attributes = md_atom_attributes
1032+
_system_attributes = _fire_system_attributes
1033+
_global_attributes = _fire_global_attributes
1034+
9831035

9841036
def frechet_cell_fire(
9851037
model: torch.nn.Module,

0 commit comments

Comments
 (0)