Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions examples/tutorials/state_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,8 @@
* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
the base SimState. Names are singular.
* Global attributes have any other shape or type, just `pbc` here. Names are singular.

You can use the `infer_property_scope` function to analyze a state's properties. This
is mostly used internally but can be useful for debugging.
"""

# %%
from torch_sim.state import infer_property_scope

scope = infer_property_scope(si_state)
print(scope)


# %% [markdown]
"""
### A Batched State
Expand Down Expand Up @@ -256,12 +246,7 @@
energy=torch.zeros((si_state.n_systems,), device=si_state.device), # Initial 0 energy
)

print("MDState properties:")
scope = infer_property_scope(md_state)
print("Global properties:", scope["global"])
print("Per-atom properties:", scope["per_atom"])
print("Per-system properties:", scope["per_system"])

print(md_state)

# %% [markdown]
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ test = [
"pymatgen>=2024.11.3",
"pytest-cov>=6",
"pytest>=8",
"pytest-xdist>=3.8.0"
]
io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"]
mace = ["mace-torch>=0.3.12"]
Expand Down
35 changes: 0 additions & 35 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
_pop_states,
_slice_state,
concatenate_states,
infer_property_scope,
initialize_state,
)

Expand All @@ -24,40 +23,6 @@
from pymatgen.core import Structure


def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None:
"""Test inference of property scope."""
scope = infer_property_scope(si_sim_state)
assert set(scope["global"]) == {"pbc"}
assert set(scope["per_atom"]) == {
"positions",
"masses",
"atomic_numbers",
"system_idx",
}
assert set(scope["per_system"]) == {"cell"}


def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
"""Test inference of property scope."""
state = MDState(
**asdict(si_sim_state),
momenta=torch.randn_like(si_sim_state.positions),
forces=torch.randn_like(si_sim_state.positions),
energy=torch.zeros((1,)),
)
scope = infer_property_scope(state)
assert set(scope["global"]) == {"pbc"}
assert set(scope["per_atom"]) == {
"positions",
"masses",
"atomic_numbers",
"system_idx",
"forces",
"momenta",
}
assert set(scope["per_system"]) == {"cell", "energy"}


def test_slice_substate(
si_double_sim_state: ts.SimState, si_sim_state: ts.SimState
) -> None:
Expand Down
52 changes: 48 additions & 4 deletions torch_sim/integrators/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch_sim.state import SimState


@dataclass
class MDState(SimState):
"""State information for molecular dynamics simulations.

Expand All @@ -36,9 +35,54 @@ class MDState(SimState):
dtype (torch.dtype): Data type of tensors
"""

momenta: torch.Tensor
energy: torch.Tensor
forces: torch.Tensor
def __init__(
self,
*,
momenta: torch.Tensor,
energy: torch.Tensor,
forces: torch.Tensor,
positions: torch.Tensor,
masses: torch.Tensor,
atomic_numbers: torch.Tensor,
cell: torch.Tensor,
pbc: bool,
system_idx: torch.Tensor | None = None,
):
super().__init__(
positions=positions,
masses=masses,
atomic_numbers=atomic_numbers,
cell=cell,
pbc=pbc,
system_idx=system_idx,
)
self.node_features["momenta"] = momenta
self.node_features["energy"] = energy
self.node_features["forces"] = forces

@property
def momenta(self) -> torch.Tensor:
return self.node_features["momenta"]

@momenta.setter
def momenta(self, momenta: torch.Tensor) -> None:
self.node_features["momenta"] = momenta

@property
def energy(self) -> torch.Tensor:
return self.node_features["energy"]

@energy.setter
def energy(self, energy: torch.Tensor) -> None:
self.node_features["energy"] = energy

@property
def forces(self) -> torch.Tensor:
return self.node_features["forces"]

@forces.setter
def forces(self, forces: torch.Tensor) -> None:
self.node_features["forces"] = forces

@property
def velocities(self) -> torch.Tensor:
Expand Down
Loading
Loading