Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,16 @@ conflicts = [
{ extra = "sevenn" },
],
]


[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
Comment on lines +156 to +166
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix the trailing newline issue flagged by the linter.

The pipeline failure indicates a missing trailing newline at the end of the file.

 reportUnknownVariableType = "none"
 reportUnknownArgumentType = "none"
 reportMissingTypeStubs = "none"
+
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
🤖 Prompt for AI Agents
In pyproject.toml around lines 156 to 166, the file is missing a trailing
newline at EOF causing the linter/pipeline failure; open the file and add a
single POSIX newline character at the end of the file (ensure the final line
ends with '\n'), save and commit so the file ends with a newline.

73 changes: 38 additions & 35 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import typing
import warnings
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Generator, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload

import torch

Expand Down Expand Up @@ -194,11 +194,7 @@ def n_atoms(self) -> int:
@property
def n_atoms_per_system(self) -> torch.Tensor:
"""Number of atoms per system."""
return (
self.system_idx.bincount()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

system_idx will never be none

if self.system_idx is not None
else torch.tensor([self.n_atoms], device=self.device)
)
return self.system_idx.bincount()

@property
def n_atoms_per_batch(self) -> torch.Tensor:
Expand Down Expand Up @@ -412,7 +408,7 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) ->

return _slice_state(self, system_indices)

def __init_subclass__(cls, **kwargs) -> None:
def __init_subclass__(cls, **kwargs: Any) -> None:
"""Enforce that all derived states cannot have tensor attributes that can also be
None. This is because torch.concatenate cannot concat between a tensor and a None.
See https://github.com/Radical-AI/torch-sim/pull/219 for more details.
Expand All @@ -431,7 +427,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None:
for attr_name, attr_typehint in type_hints.items():
origin = typing.get_origin(attr_typehint)

is_union = origin is typing.Union
is_union = origin is typing.Union # pyright: ignore[reportDeprecated]
if not is_union and origin is not None:
# For Python 3.10+ `|` syntax, origin is types.UnionType
# We check by name to be robust against module reloading/patching issues
Expand Down Expand Up @@ -572,10 +568,10 @@ def _normalize_system_indices(
if isinstance(system_indices, slice):
# Let PyTorch handle the slice conversion with negative indices
return torch.arange(n_systems, device=device)[system_indices]
if isinstance(system_indices, torch.Tensor):
if isinstance(system_indices, torch.Tensor): # pyright: ignore[reportUnnecessaryIsInstance]
# Handle negative indices in tensors
return torch.where(system_indices < 0, n_systems + system_indices, system_indices)
raise TypeError(f"Unsupported index type: {type(system_indices)}")
raise TypeError(f"Unsupported index type: {type(system_indices)}") # pyright: ignore[reportUnreachable]


def state_to_device(
Expand Down Expand Up @@ -606,14 +602,25 @@ def state_to_device(
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.to(device=device)

if dtype is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we assign dtype, so this can never be None

attrs["positions"] = attrs["positions"].to(dtype=dtype)
attrs["masses"] = attrs["masses"].to(dtype=dtype)
attrs["cell"] = attrs["cell"].to(dtype=dtype)
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
attrs["positions"] = attrs["positions"].to(dtype=dtype)
attrs["masses"] = attrs["masses"].to(dtype=dtype)
attrs["cell"] = attrs["cell"].to(dtype=dtype)
attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
return type(state)(**attrs)


@overload
def get_attrs_for_scope(
state: SimState, scope: Literal["per-atom", "per-system"]
) -> Generator[tuple[str, torch.Tensor], None, None]: ...


@overload
def get_attrs_for_scope(
state: SimState, scope: Literal["global"]
) -> Generator[tuple[str, Any], None, None]: ...


def get_attrs_for_scope(
state: SimState, scope: Literal["per-atom", "per-system", "global"]
) -> Generator[tuple[str, Any], None, None]:
Expand All @@ -627,15 +634,11 @@ def get_attrs_for_scope(
Returns:
Generator[tuple[str, Any], None, None]: A generator of attribute names and values
"""
match scope:
case "per-atom":
attr_names = state._atom_attributes # noqa: SLF001
case "per-system":
attr_names = state._system_attributes # noqa: SLF001
case "global":
attr_names = state._global_attributes # noqa: SLF001
case _:
raise ValueError(f"Unknown scope: {scope!r}")
attr_names = {
"per-atom": state._atom_attributes, # noqa: SLF001
"per-system": state._system_attributes, # noqa: SLF001
"global": state._global_attributes, # noqa: SLF001
}[scope]
for attr_name in attr_names:
yield attr_name, getattr(state, attr_name)

Expand All @@ -644,7 +647,7 @@ def _filter_attrs_by_mask(
state: SimState,
atom_mask: torch.Tensor,
system_mask: torch.Tensor,
) -> dict:
) -> dict[str, Any]:
"""Filter attributes by atom and system masks.

Selects subsets of attributes based on boolean masks for atoms and systems.
Expand Down Expand Up @@ -708,21 +711,21 @@ def _split_state(
list[SimState]: A list of SimState objects, each containing a single
system
"""
system_sizes = torch.bincount(state.system_idx).tolist()
system_sizes: list[int] = torch.bincount(state.system_idx).tolist()

split_per_atom = {}
split_per_atom: dict[str, Sequence[torch.Tensor]] = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
if attr_name != "system_idx":
split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)

split_per_system = {}
split_per_system: dict[str, Sequence[torch.Tensor]] = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)

global_attrs = dict(get_attrs_for_scope(state, "global"))

# Create a state for each system
states = []
states: list[SimStateVar] = []
n_systems = len(system_sizes)
for i in range(n_systems):
system_attrs = {
Expand All @@ -740,7 +743,7 @@ def _split_state(
# Add the global attributes
**global_attrs,
}
states.append(type(state)(**system_attrs))
states.append(type(state)(**system_attrs)) # pyright: ignore[reportArgumentType]

return states

Expand Down Expand Up @@ -872,9 +875,9 @@ def concatenate_states(
concatenated = dict(get_attrs_for_scope(first_state, "global"))

# Pre-allocate lists for tensors to concatenate
per_atom_tensors = defaultdict(list)
per_system_tensors = defaultdict(list)
new_system_indices = []
per_atom_tensors = defaultdict[str, list[torch.Tensor]](list)
per_system_tensors = defaultdict[str, list[torch.Tensor]](list)
new_system_indices: list[torch.Tensor] = []
system_offset = 0

# Process all states in a single pass
Expand Down Expand Up @@ -944,7 +947,7 @@ def initialize_state(
return state_to_device(system, device, dtype)

if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
if not all(cast("SimState", state).n_systems == 1 for state in system):
if not all(state.n_systems == 1 for state in system): # pyright: ignore[reportAttributeAccessIssue]
raise ValueError(
"When providing a list of states, to the initialize_state function, "
"all states must have n_systems == 1. To fix this, you can split the "
Expand Down
22 changes: 11 additions & 11 deletions torch_sim/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Types used across torch-sim."""

from enum import Enum
from typing import TYPE_CHECKING, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Literal, TypeVar

import torch

Expand Down Expand Up @@ -40,13 +40,13 @@ class BravaisType(Enum):
TRICLINIC = "triclinic"


StateLike = Union[
"Atoms",
"Structure",
"PhonopyAtoms",
list["Atoms"],
list["Structure"],
list["PhonopyAtoms"],
SimStateVar,
list[SimStateVar],
]
StateLike = (
Atoms
| Structure
| PhonopyAtoms
| list[Atoms]
| list[Structure]
| list[PhonopyAtoms]
| SimState
| list[SimState]
)
Loading