diff --git a/examples/tutorials/state_tutorial.py b/examples/tutorials/state_tutorial.py index 0bc9e341..065ca4a7 100644 --- a/examples/tutorials/state_tutorial.py +++ b/examples/tutorials/state_tutorial.py @@ -78,18 +78,8 @@ * Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for the base SimState. Names are singular. * Global attributes have any other shape or type, just `pbc` here. Names are singular. - -You can use the `infer_property_scope` function to analyze a state's properties. This -is mostly used internally but can be useful for debugging. """ -# %% -from torch_sim.state import infer_property_scope - -scope = infer_property_scope(si_state) -print(scope) - - # %% [markdown] """ ### A Batched State @@ -256,12 +246,7 @@ energy=torch.zeros((si_state.n_systems,), device=si_state.device), # Initial 0 energy ) -print("MDState properties:") -scope = infer_property_scope(md_state) -print("Global properties:", scope["global"]) -print("Per-atom properties:", scope["per_atom"]) -print("Per-system properties:", scope["per_system"]) - +print(md_state) # %% [markdown] """ diff --git a/pyproject.toml b/pyproject.toml index 723edafb..399d2919 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ test = [ "pymatgen>=2024.11.3", "pytest-cov>=6", "pytest>=8", + "pytest-xdist>=3.8.0" ] io = ["ase>=3.24", "phonopy>=2.37.0", "pymatgen>=2024.11.3"] mace = ["mace-torch>=0.3.12"] diff --git a/tests/test_state.py b/tests/test_state.py index ea57dd3a..bb783dfb 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,7 +13,6 @@ _pop_states, _slice_state, concatenate_states, - infer_property_scope, initialize_state, ) @@ -24,40 +23,6 @@ from pymatgen.core import Structure -def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - scope = infer_property_scope(si_sim_state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - } - assert set(scope["per_system"]) == {"cell"} - - -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: - """Test inference of property scope.""" - state = MDState( - **asdict(si_sim_state), - momenta=torch.randn_like(si_sim_state.positions), - forces=torch.randn_like(si_sim_state.positions), - energy=torch.zeros((1,)), - ) - scope = infer_property_scope(state) - assert set(scope["global"]) == {"pbc"} - assert set(scope["per_atom"]) == { - "positions", - "masses", - "atomic_numbers", - "system_idx", - "forces", - "momenta", - } - assert set(scope["per_system"]) == {"cell", "energy"} - - def test_slice_substate( si_double_sim_state: ts.SimState, si_sim_state: ts.SimState ) -> None: diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index ce15877d..3d2f4570 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -9,7 +9,6 @@ from torch_sim.state import SimState -@dataclass class MDState(SimState): """State information for molecular dynamics simulations. 
@@ -36,9 +35,54 @@ class MDState(SimState):
         dtype (torch.dtype): Data type of tensors
     """
 
-    momenta: torch.Tensor
-    energy: torch.Tensor
-    forces: torch.Tensor
+    def __init__(
+        self,
+        *,
+        momenta: torch.Tensor,
+        energy: torch.Tensor,
+        forces: torch.Tensor,
+        positions: torch.Tensor,
+        masses: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+        cell: torch.Tensor,
+        pbc: bool,
+        system_idx: torch.Tensor | None = None,
+    ) -> None:
+        super().__init__(
+            positions=positions,
+            masses=masses,
+            atomic_numbers=atomic_numbers,
+            cell=cell,
+            pbc=pbc,
+            system_idx=system_idx,
+        )
+        self.node_features["momenta"] = momenta
+        self.node_features["forces"] = forces
+        self.system_features["energy"] = energy  # per-system: shape (n_systems,)
+
+    @property
+    def momenta(self) -> torch.Tensor:
+        return self.node_features["momenta"]
+
+    @momenta.setter
+    def momenta(self, momenta: torch.Tensor) -> None:
+        self.node_features["momenta"] = momenta
+
+    @property
+    def energy(self) -> torch.Tensor:
+        return self.system_features["energy"]
+
+    @energy.setter
+    def energy(self, energy: torch.Tensor) -> None:
+        self.system_features["energy"] = energy
+
+    @property
+    def forces(self) -> torch.Tensor:
+        return self.node_features["forces"]
+
+    @forces.setter
+    def forces(self, forces: torch.Tensor) -> None:
+        self.node_features["forces"] = forces
 
     @property
     def velocities(self) -> torch.Tensor:
diff --git a/torch_sim/state.py b/torch_sim/state.py
index ce21ef9b..07c9635f 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -7,8 +7,7 @@
 import copy
 import importlib
 import warnings
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Literal, Self
+from typing import TYPE_CHECKING, Any, Self, TypeVar, cast
 
 import torch
 
@@ -21,8 +20,9 @@
     from phonopy.structure.atoms import PhonopyAtoms
     from pymatgen.core import Structure
 
+DefaultFeatures = ["positions", "masses", "atomic_numbers", "cell", "pbc", "system_idx"]
+
 
-@dataclass
 class SimState:
     """State representation for atomistic systems with batched operations support.
 
@@ -76,16 +76,34 @@ class SimState:
     >>> cloned_state = state.clone()
     """
 
-    positions: torch.Tensor
-    masses: torch.Tensor
-    cell: torch.Tensor
-    pbc: bool  # TODO: do all calculators support mixed pbc?
-    atomic_numbers: torch.Tensor
-    system_idx: torch.Tensor | None = field(default=None, kw_only=True)
+    node_features: dict[str, Any]
+    system_features: dict[str, Any]
+    global_features: dict[str, Any]
+
+    def __init__(
+        self,
+        positions: torch.Tensor,
+        masses: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+        cell: torch.Tensor,
+        *,
+        pbc: bool,
+        system_idx: torch.Tensor | None = None,
+    ) -> None:
+        """Initialize the SimState."""
+        self.node_features = {
+            "positions": positions,
+            "masses": masses,
+            "atomic_numbers": atomic_numbers,
+        }
+        # cell has shape (n_systems, 3, 3), so it is stored as a per-system
+        # feature; this lets slicing, splitting, and concatenation treat it
+        # like any other per-system tensor. pbc is a true global flag.
+        self.system_features = {"cell": cell}
+        self.global_features = {"pbc": pbc}
 
-    def __post_init__(self) -> None:
-        """Validate and process the state after initialization."""
-        # data validation and fill system_idx
+        # Validate and process the state after initialization.
+        # data validation and fill system_idx
         # should make pbc a tensor here
         # if devices aren't all the same, raise an error, in a clean way
         devices = {
@@ -107,13 +125,13 @@ def __post_init__(self) -> None:
             f"masses {shapes[1]}, atomic_numbers {shapes[2]}"
         )
 
-        if self.cell.ndim != 3 and self.system_idx is None:
+        if self.cell.ndim != 3 and system_idx is None:
             self.cell = self.cell.unsqueeze(0)
 
         if self.cell.shape[-2:] != (3, 3):
             raise ValueError("Cell must have shape (n_systems, 3, 3)")
 
-        if self.system_idx is None:
+        if system_idx is None:
             self.system_idx = torch.zeros(
                 self.n_atoms, device=self.device, dtype=torch.int64
             )
@@ -121,15 +139,77 @@ def __post_init__(self) -> None:
             # assert that system indices are unique consecutive integers
             # TODO(curtis): I feel like this logic is not reliable.
             # I'll come up with something better later.
-            _, counts = torch.unique_consecutive(self.system_idx, return_counts=True)
-            if not torch.all(counts == torch.bincount(self.system_idx)):
+            _, counts = torch.unique_consecutive(system_idx, return_counts=True)
+            if not torch.all(counts == torch.bincount(system_idx)):
                 raise ValueError("System indices must be unique consecutive integers")
+            self.system_idx = system_idx
 
         if self.cell.shape[0] != self.n_systems:
             raise ValueError(
-                f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}"
+                f"Cell must have shape (n_systems, 3, 3) = ({self.n_systems}, 3, 3), "
+                f"got {self.cell.shape}"
             )
 
+    @classmethod
+    def from_features(
+        cls,
+        node_features: dict[str, Any],
+        system_features: dict[str, Any],
+        global_features: dict[str, Any],
+    ) -> Self:
+        """Build a state from per-atom, per-system, and global feature dicts."""
+        # Feature keys map directly onto constructor keyword arguments, so this
+        # also works for subclasses such as MDState.
+        return cls(**node_features, **system_features, **global_features)
+
+    @property
+    def positions(self) -> torch.Tensor:
+        return self.node_features["positions"]
+
+    @positions.setter
+    def positions(self, positions: torch.Tensor) -> None:
+        self.node_features["positions"] = positions
+
+    @property
+    def masses(self) -> torch.Tensor:
+        return self.node_features["masses"]
+
+    @masses.setter
+    def masses(self, masses: torch.Tensor) -> None:
+        self.node_features["masses"] = masses
+
+    @property
+    def atomic_numbers(self) -> torch.Tensor:
+        return self.node_features["atomic_numbers"]
+
+    @atomic_numbers.setter
+    def atomic_numbers(self, atomic_numbers: torch.Tensor) -> None:
+        self.node_features["atomic_numbers"] = atomic_numbers
+
+    @property
+    def cell(self) -> torch.Tensor:
+        return self.system_features["cell"]
+
+    @cell.setter
+    def cell(self, cell: torch.Tensor) -> None:
+        self.system_features["cell"] = cell
+
+    @property
+    def pbc(self) -> bool:
+        return self.global_features["pbc"]
+
+    @pbc.setter
+    def pbc(self, pbc: bool) -> None:
+        self.global_features["pbc"] = pbc
+
+    @property
+    def system_idx(self) -> torch.Tensor:
+        return self.node_features["system_idx"]
+
+    @system_idx.setter
+    def system_idx(self, system_idx: torch.Tensor) -> None:
+        self.node_features["system_idx"] = system_idx
+
     @property
     def wrap_positions(self) -> torch.Tensor:
         """Atomic positions wrapped according to periodic boundary conditions if pbc=True,
@@ -226,7 +306,9 @@ def n_systems(self) -> int:
 
     @property
     def volume(self) -> torch.Tensor:
        """Volume of the system."""
-        return torch.det(self.cell) if self.pbc else None
+        if not self.pbc:
+            raise ValueError("Volume is only defined for periodic systems")
+        return torch.det(self.cell)
@property def column_vector_cell(self) -> torch.Tensor: @@ -336,7 +418,7 @@ def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Se for attr_name, attr_value in vars(modified_state).items(): setattr(self, attr_name, attr_value) - return popped_states + return cast("list[Self]", popped_states) def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None @@ -376,14 +458,8 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> class DeformGradMixin: """Mixin for states that support deformation gradients.""" - @property - def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses. - - Returns: - The momenta of the particles - """ - return self.velocities * self.masses.unsqueeze(-1) + reference_cell: torch.Tensor + row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: @@ -457,11 +533,14 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") +SimStateT = TypeVar("SimStateT", bound=SimState) + + def state_to_device( - state: SimState, + state: SimStateT, device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> Self: +) -> SimStateT: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and @@ -493,125 +572,8 @@ def state_to_device( return type(state)(**attrs) -def infer_property_scope( - state: SimState, - ambiguous_handling: Literal["error", "globalize", "globalize_warn"] = "error", -) -> dict[Literal["global", "per_atom", "per_system"], list[str]]: - """Infer whether a property is global, per-atom, or per-system. - - Analyzes the shapes of tensor attributes to determine their scope within - the atomistic system representation. - - Args: - state (SimState): The state to analyze - ambiguous_handling ("error" | "globalize" | "globalize_warn"): How to - handle properties with ambiguous scope. Options: - - "error": Raise an error for ambiguous properties - - "globalize": Treat ambiguous properties as global - - "globalize_warn": Treat ambiguous properties as global with a warning - - Returns: - dict[Literal["global", "per_atom", "per_system"], list[str]]: Map of scope - category to list of property names - - Raises: - ValueError: If n_atoms equals n_systems (making scope inference ambiguous) or - if ambiguous_handling="error" and an ambiguous property is encountered - """ - # TODO: this cannot effectively resolve global properties with - # length of n_atoms or n_systems, they will be classified incorrectly, - # no clear fix - - if state.n_atoms == state.n_systems: - raise ValueError( - f"n_atoms ({state.n_atoms}) and n_systems ({state.n_systems}) are equal, " - "which means shapes cannot be inferred unambiguously." 
- ) - - scope = { - "global": [], - "per_atom": [], - "per_system": [], - } - - # Iterate through all attributes - for attr_name, attr_value in vars(state).items(): - # Handle scalar values (global properties) - if not isinstance(attr_value, torch.Tensor): - scope["global"].append(attr_name) - continue - - # Handle tensor properties based on shape - shape = attr_value.shape - - # Empty tensor case - if len(shape) == 0: - scope["global"].append(attr_name) - # Vector/matrix with first dimension matching number of atoms - elif shape[0] == state.n_atoms: - scope["per_atom"].append(attr_name) - # Tensor with first dimension matching number of systems - elif shape[0] == state.n_systems: - scope["per_system"].append(attr_name) - # Any other shape is ambiguous - elif ambiguous_handling == "error": - raise ValueError( - f"Cannot categorize property '{attr_name}' with shape {shape}. " - f"Expected first dimension to be either {state.n_atoms} (per-atom) or " - f"{state.n_systems} (per-system), or a scalar (global)." - ) - elif ambiguous_handling in ("globalize", "globalize_warn"): - scope["global"].append(attr_name) - - if ambiguous_handling == "globalize_warn": - warnings.warn( - f"Property '{attr_name}' with shape {shape} is ambiguous, " - "treating as global. This may lead to unexpected behavior " - "and suggests the State is not being used as intended.", - stacklevel=1, - ) - - return scope - - -def _get_property_attrs( - state: SimState, ambiguous_handling: Literal["error", "globalize"] = "error" -) -> dict[str, dict]: - """Get global, per-atom, and per-system attributes from a state. - - Categorizes all attributes of the state based on their scope - (global, per-atom, or per-system). - - Args: - state (SimState): The state to extract attributes from - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties - - Returns: - dict[str, dict]: Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values - """ - scope = infer_property_scope(state, ambiguous_handling=ambiguous_handling) - - attrs = {"global": {}, "per_atom": {}, "per_system": {}} - - # Process global properties - for attr_name in scope["global"]: - attrs["global"][attr_name] = getattr(state, attr_name) - - # Process per-atom properties - for attr_name in scope["per_atom"]: - attrs["per_atom"][attr_name] = getattr(state, attr_name) - - # Process per-system properties - for attr_name in scope["per_system"]: - attrs["per_system"][attr_name] = getattr(state, attr_name) - - return attrs - - def _filter_attrs_by_mask( - attrs: dict[str, dict], + state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, ) -> dict: @@ -620,8 +582,7 @@ def _filter_attrs_by_mask( Selects subsets of attributes based on boolean masks for atoms and systems. 
Args: - attrs (dict[str, dict]): Keys are 'global', 'per_atom', and 'per_system', each - containing a dictionary of attribute names to values + state (SimState): The state to filter atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) system_mask (torch.Tensor): Boolean mask for systems to include with shape @@ -633,10 +594,10 @@ def _filter_attrs_by_mask( filtered_attrs = {} # Copy global attributes directly - filtered_attrs.update(attrs["global"]) + filtered_attrs.update(state.global_features) # Filter per-atom attributes - for attr_name, attr_value in attrs["per_atom"].items(): + for attr_name, attr_value in state.node_features.items(): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] @@ -660,16 +621,15 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in state.system_features.items(): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs def _split_state( - state: SimState, - ambiguous_handling: Literal["error", "globalize"] = "error", -) -> list[SimState]: + state: SimStateT, +) -> list[SimStateT]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving @@ -677,28 +637,22 @@ def _split_state( Args: state (SimState): The SimState to split - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. - Returns: list[SimState]: A list of SimState objects, each containing a single system """ - attrs = _get_property_attrs(state, ambiguous_handling) system_sizes = torch.bincount(state.system_idx).tolist() # Split per-atom attributes by system split_per_atom = {} - for attr_name, attr_value in attrs["per_atom"].items(): + for attr_name, attr_value in state.node_features.items(): if attr_name == "system_idx": continue split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) # Split per-system attributes into individual elements split_per_system = {} - for attr_name, attr_value in attrs["per_system"].items(): + for attr_name, attr_value in state.system_features.items(): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) # Create a state for each system @@ -717,7 +671,7 @@ def _split_state( for attr_name in split_per_system }, # Add the global attributes - **attrs["global"], + **state.global_features, } states.append(type(state)(**system_attrs)) @@ -727,7 +681,6 @@ def _split_state( def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. @@ -736,10 +689,6 @@ def _pop_states( Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: tuple[SimState, list[SimState]]: A tuple containing: @@ -755,8 +704,6 @@ def _pop_states( if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to keep and pop system_range = torch.arange(state.n_systems, device=state.device) pop_system_mask = torch.isin(system_range, pop_indices) @@ -766,24 +713,23 @@ def _pop_states( keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states - keep_attrs = _filter_attrs_by_mask(attrs, keep_atom_mask, keep_system_mask) - pop_attrs = _filter_attrs_by_mask(attrs, pop_atom_mask, pop_system_mask) + keep_attrs = _filter_attrs_by_mask(state, keep_atom_mask, keep_system_mask) + pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) # Create and split the pop state pop_state = type(state)(**pop_attrs) - pop_states = _split_state(pop_state, ambiguous_handling) + pop_states = _split_state(pop_state) return keep_state, pop_states def _slice_state( - state: SimState, + state: SimStateT, system_indices: list[int] | torch.Tensor, - ambiguous_handling: Literal["error", "globalize"] = "error", -) -> SimState: +) -> SimStateT: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving @@ -793,10 +739,6 @@ def _slice_state( state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state - ambiguous_handling ("error" | "globalize"): How to handle ambiguous - properties. If "error", an error is raised if a property has ambiguous - scope. If "globalize", properties with ambiguous scope are treated as - global. 
Returns: SimState: A new SimState object containing only the specified systems @@ -812,15 +754,13 @@ def _slice_state( if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") - attrs = _get_property_attrs(state, ambiguous_handling) - # Create masks for the atoms and systems to include system_range = torch.arange(state.n_systems, device=state.device) system_mask = torch.isin(system_range, system_indices) atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes - filtered_attrs = _filter_attrs_by_mask(attrs, atom_mask, system_mask) + filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs) @@ -861,19 +801,12 @@ def concatenate_states( # Use the target device or default to the first state's device target_device = device or first_state.device - # Get property scopes from the first state to identify - # global/per-atom/per-system properties - first_scope = infer_property_scope(first_state) - global_props = set(first_scope["global"]) - per_atom_props = set(first_scope["per_atom"]) - per_system_props = set(first_scope["per_system"]) - # Initialize result with global properties from first state - concatenated = {prop: getattr(first_state, prop) for prop in global_props} + concatenated = first_state.global_features.copy() # Pre-allocate lists for tensors to concatenate - per_atom_tensors = {prop: [] for prop in per_atom_props} - per_system_tensors = {prop: [] for prop in per_system_props} + per_atom_tensors = {prop: [] for prop in first_state.node_features.keys()} + per_system_tensors = {prop: [] for prop in first_state.system_features.keys()} new_system_indices = [] system_offset = 0 @@ -883,14 +816,12 @@ def concatenate_states( if state.device != target_device: state = state_to_device(state, target_device) - # Collect per-atom properties - for prop in per_atom_props: - # if hasattr(state, prop): + # Collect per-node properties + for prop in first_state.node_features.keys(): per_atom_tensors[prop].append(getattr(state, prop)) # Collect per-system properties - for prop in per_system_props: - # if hasattr(state, prop): + for prop in first_state.system_features.keys(): per_system_tensors[prop].append(getattr(state, prop)) # Update system indices @@ -901,11 +832,9 @@ def concatenate_states( # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): - # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) for prop, tensors in per_system_tensors.items(): - # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) # Concatenate system indices @@ -943,7 +872,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_systems == 1 for state in system): + if not all(cast("SimState", state).n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the "
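Usage sketch (not part of the diff): with the feature-dict storage introduced above, attribute access stays the same while the per-atom / per-system / global grouping is now explicit instead of being inferred from tensor shapes by the removed `infer_property_scope` helper. A minimal example, assuming this branch of torch-sim is installed; the tensors below are illustrative placeholders:

    import torch

    from torch_sim.state import SimState

    state = SimState(
        positions=torch.zeros((2, 3)),
        masses=torch.ones(2),
        atomic_numbers=torch.tensor([14, 14]),
        cell=torch.eye(3).unsqueeze(0),  # shape (n_systems=1, 3, 3)
        pbc=True,
    )

    # Properties still behave like the old dataclass fields...
    assert state.positions.shape == (2, 3)

    # ...while the scope of each property is read directly from the dicts.
    print(sorted(state.node_features))    # per-atom, e.g. positions, system_idx
    print(sorted(state.system_features))  # per-system, e.g. cell
    print(sorted(state.global_features))  # global, e.g. pbc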