From 9cf3f985e294d2852b9076b78428218f7d5e1af7 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 29 Aug 2025 20:16:58 -0700 Subject: [PATCH 1/8] minor typing updates --- pyproject.toml | 4 ++++ torch_sim/state.py | 6 +----- torch_sim/typing.py | 20 +++++++++----------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 723edafb..8b98523e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,3 +151,7 @@ conflicts = [ { extra = "sevenn" }, ], ] + + +[tool.basedpyright] +reportImplicitStringConcatenation = "none" diff --git a/torch_sim/state.py b/torch_sim/state.py index fa898e1c..259c1aed 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -194,11 +194,7 @@ def n_atoms(self) -> int: @property def n_atoms_per_system(self) -> torch.Tensor: """Number of atoms per system.""" - return ( - self.system_idx.bincount() - if self.system_idx is not None - else torch.tensor([self.n_atoms], device=self.device) - ) + return self.system_idx.bincount() @property def n_atoms_per_batch(self) -> torch.Tensor: diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 94ec44ca..774e0b6f 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -1,7 +1,7 @@ """Types used across torch-sim.""" from enum import Enum -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Literal, TypeVar import torch @@ -40,13 +40,11 @@ class BravaisType(Enum): TRICLINIC = "triclinic" -StateLike = Union[ - "Atoms", - "Structure", - "PhonopyAtoms", - list["Atoms"], - list["Structure"], - list["PhonopyAtoms"], - SimStateVar, - list[SimStateVar], -] +StateLike = ( + Literal["Atoms", "Structure", "PhonopyAtoms"] + | list["Atoms"] + | list["Structure"] + | list["PhonopyAtoms"] + | SimStateVar + | list[SimStateVar] +) From 41151b68968c9afbbf7f1bb5b102828c1c443bdc Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Fri, 29 Aug 2025 20:44:51 -0700 Subject: [PATCH 2/8] ignore rules to find the real problems --- pyproject.toml | 9 +++++++++ torch_sim/state.py | 39 +++++++++++++++++---------------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8b98523e..61b6553d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,3 +155,12 @@ conflicts = [ [tool.basedpyright] reportImplicitStringConcatenation = "none" +reportPrivateUsage = "none" # since ruff will catch this +reportAny= "none" +reportExplicitAny= "none" +reportUnknownMemberType = "none" +reportUnknownVariableType = "none" +reportUnknownArgumentType = "none" +reportMissingTypeStubs = "none" +reportPrivateLocalImportUsage = "none" +reportUnknownParameterType = "none" \ No newline at end of file diff --git a/torch_sim/state.py b/torch_sim/state.py index 259c1aed..b9dbceca 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -408,7 +408,7 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. @@ -427,7 +427,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: for attr_name, attr_typehint in type_hints.items(): origin = typing.get_origin(attr_typehint) - is_union = origin is typing.Union + is_union = origin is typing.Union # pyright: ignore[reportDeprecated] if not is_union and origin is not None: # For Python 3.10+ `|` syntax, origin is types.UnionType # We check by name to be robust against module reloading/patching issues @@ -568,10 +568,10 @@ def _normalize_system_indices( if isinstance(system_indices, slice): # Let PyTorch handle the slice conversion with negative indices return torch.arange(n_systems, device=device)[system_indices] - if isinstance(system_indices, torch.Tensor): + if isinstance(system_indices, torch.Tensor): # pyright: ignore[reportUnnecessaryIsInstance] # Handle negative indices in tensors return torch.where(system_indices < 0, n_systems + system_indices, system_indices) - raise TypeError(f"Unsupported index type: {type(system_indices)}") + raise TypeError(f"Unsupported index type: {type(system_indices)}") # pyright: ignore[reportUnreachable] def state_to_device( @@ -602,11 +602,10 @@ def state_to_device( if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) - if dtype is not None: - attrs["positions"] = attrs["positions"].to(dtype=dtype) - attrs["masses"] = attrs["masses"].to(dtype=dtype) - attrs["cell"] = attrs["cell"].to(dtype=dtype) - attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + attrs["positions"] = attrs["positions"].to(dtype=dtype) + attrs["masses"] = attrs["masses"].to(dtype=dtype) + attrs["cell"] = attrs["cell"].to(dtype=dtype) + attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) return type(state)(**attrs) @@ -623,15 +622,11 @@ def get_attrs_for_scope( Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ - match scope: - case "per-atom": - attr_names = state._atom_attributes # noqa: SLF001 - case "per-system": - attr_names = state._system_attributes # noqa: SLF001 - case "global": - attr_names = state._global_attributes # noqa: SLF001 - case _: - raise ValueError(f"Unknown scope: {scope!r}") + attr_names = { + "per-atom": state._atom_attributes, # noqa: SLF001 + "per-system": state._system_attributes, # noqa: SLF001 + "global": state._global_attributes, # noqa: SLF001 + }[scope] for attr_name in attr_names: yield attr_name, getattr(state, attr_name) @@ -640,7 +635,7 @@ def _filter_attrs_by_mask( state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, -) -> dict: +) -> dict[str, Any]: """Filter attributes by atom and system masks. Selects subsets of attributes based on boolean masks for atoms and systems. @@ -868,9 +863,9 @@ def concatenate_states( concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = defaultdict(list) - per_system_tensors = defaultdict(list) - new_system_indices = [] + per_atom_tensors = defaultdict[str, list[Any]](list) + per_system_tensors = defaultdict[str, list[Any]](list) + new_system_indices: list[torch.Tensor] = [] system_offset = 0 # Process all states in a single pass From d85e7f9e10d72d32f9fffed859b6fd466d843591 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 10:41:39 -0700 Subject: [PATCH 3/8] fix statelike and wip fix all of state --- pyproject.toml | 8 +++-- torch_sim/state.py | 84 +++++++++++++++++++++++---------------------- torch_sim/typing.py | 12 ++++--- 3 files changed, 56 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 61b6553d..08764afa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,9 +158,11 @@ reportImplicitStringConcatenation = "none" reportPrivateUsage = "none" # since ruff will catch this reportAny= "none" reportExplicitAny= "none" +include = ["torch_sim", "tests", "examples", "docs"] +exclude = [".venv"] reportUnknownMemberType = "none" reportUnknownVariableType = "none" reportUnknownArgumentType = "none" -reportMissingTypeStubs = "none" -reportPrivateLocalImportUsage = "none" -reportUnknownParameterType = "none" \ No newline at end of file +# reportMissingTypeStubs = "none" +# reportPrivateLocalImportUsage = "none" +# reportUnknownParameterType = "none" \ No newline at end of file diff --git a/torch_sim/state.py b/torch_sim/state.py index b9dbceca..d86fad2f 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -9,9 +9,9 @@ import typing import warnings from collections import defaultdict -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Self, cast import torch @@ -609,25 +609,27 @@ def state_to_device( return type(state)(**attrs) -def get_attrs_for_scope( - state: SimState, scope: Literal["per-atom", "per-system", "global"] -) -> Generator[tuple[str, Any], None, None]: - """Get attributes for a given scope. +def get_attrs_for_per_atom_scope( + state: SimState, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get attributes for the per-atom scope.""" + for attr_name in state._atom_attributes: # noqa: SLF001 + yield attr_name, getattr(state, attr_name) - Args: - state (SimState): The state to get attributes for - scope (Literal["per-atom", "per-system", "global"]): The scope to get - attributes for - Returns: - Generator[tuple[str, Any], None, None]: A generator of attribute names and values - """ - attr_names = { - "per-atom": state._atom_attributes, # noqa: SLF001 - "per-system": state._system_attributes, # noqa: SLF001 - "global": state._global_attributes, # noqa: SLF001 - }[scope] - for attr_name in attr_names: +def get_attrs_for_per_system_scope( + state: SimState, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get attributes for the per-system scope.""" + for attr_name in state._system_attributes: # noqa: SLF001 + yield attr_name, getattr(state, attr_name) + + +def get_attrs_for_global_scope( + state: SimState, +) -> Generator[tuple[str, Any], None, None]: + """Get attributes for the global scope.""" + for attr_name in state._global_attributes: # noqa: SLF001 yield attr_name, getattr(state, attr_name) @@ -651,18 +653,18 @@ def _filter_attrs_by_mask( dict: Filtered attributes with appropriate handling for each scope """ # Copy global attributes directly - filtered_attrs = dict(get_attrs_for_scope(state, "global")) + filtered_attrs = dict(get_attrs_for_global_scope(state)) # Filter per-atom attributes - for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): + for attr_name, attr_value in get_attrs_for_per_atom_scope(state): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] # Get the system indices that are kept - kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ - system_mask - ] + kept_indices = torch.arange( + (attr_value.max() + 1).item(), device=attr_value.device + )[system_mask] # Create a mapping from old system indices to new consecutive indices system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)} @@ -678,7 +680,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + for attr_name, attr_value in get_attrs_for_per_system_scope(state): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -699,21 +701,21 @@ def _split_state( list[SimState]: A list of SimState objects, each containing a single system """ - system_sizes = torch.bincount(state.system_idx).tolist() + system_sizes: list[int] = torch.bincount(state.system_idx).tolist() - split_per_atom = {} - for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): + split_per_atom: dict[str, Sequence[torch.Tensor]] = {} + for attr_name, attr_value in get_attrs_for_per_atom_scope(state): if attr_name != "system_idx": split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - split_per_system = {} - for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + split_per_system: dict[str, Sequence[torch.Tensor]] = {} + for attr_name, attr_value in get_attrs_for_per_system_scope(state): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) - global_attrs = dict(get_attrs_for_scope(state, "global")) + global_attrs = dict(get_attrs_for_global_scope(state)) # Create a state for each system - states = [] + states: list[SimStateVar] = [] n_systems = len(system_sizes) for i in range(n_systems): system_attrs = { @@ -731,7 +733,7 @@ def _split_state( # Add the global attributes **global_attrs, } - states.append(type(state)(**system_attrs)) + states.append(type(state)(**system_attrs)) # pyright: ignore[reportArgumentType] return states @@ -860,11 +862,11 @@ def concatenate_states( target_device = device or first_state.device # Initialize result with global properties from first state - concatenated = dict(get_attrs_for_scope(first_state, "global")) + concatenated = dict(get_attrs_for_global_scope(first_state)) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = defaultdict[str, list[Any]](list) - per_system_tensors = defaultdict[str, list[Any]](list) + per_atom_tensors = defaultdict[str, list[torch.Tensor]](list) + per_system_tensors = defaultdict[str, list[torch.Tensor]](list) new_system_indices: list[torch.Tensor] = [] system_offset = 0 @@ -875,14 +877,14 @@ def concatenate_states( state = state_to_device(state, target_device) # Collect per-atom properties - for prop, val in get_attrs_for_scope(state, "per-atom"): + for prop, val in get_attrs_for_per_atom_scope(state): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties - for prop, val in get_attrs_for_scope(state, "per-system"): + for prop, val in get_attrs_for_per_system_scope(state): per_system_tensors[prop].append(val) # Update system indices @@ -935,13 +937,13 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(cast("SimState", state).n_systems == 1 for state in system): + if not all(state.n_systems == 1 for state in system): # pyright: ignore[reportAttributeAccessIssue] raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) - return concatenate_states(system) + return concatenate_states(system) # pyright: ignore[reportArgumentType] converters = [ ("pymatgen.core", "Structure", ts.io.structures_to_state), @@ -958,7 +960,7 @@ def initialize_state( if isinstance(system, cls) or ( isinstance(system, list) and all(isinstance(s, cls) for s in system) ): - return converter_func(system, device, dtype) + return converter_func(system, device, dtype) # pyright: ignore[reportArgumentType] except ImportError: continue diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 774e0b6f..26cbf38f 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -1,7 +1,7 @@ """Types used across torch-sim.""" from enum import Enum -from typing import TYPE_CHECKING, Literal, TypeVar +from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar import torch @@ -41,10 +41,14 @@ class BravaisType(Enum): StateLike = ( - Literal["Atoms", "Structure", "PhonopyAtoms"] + Atoms + | Structure + | PhonopyAtoms | list["Atoms"] | list["Structure"] | list["PhonopyAtoms"] - | SimStateVar - | list[SimStateVar] + | SimState + | list[SimState] + # | SimStateVar + # | list[SimStateVar] ) From f706b75736c4952ff0af7b2bbaf87e8e34b3f22a Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 10:48:20 -0700 Subject: [PATCH 4/8] fixed most basedpyright issues in state.py --- torch_sim/state.py | 4 ++-- torch_sim/typing.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index d86fad2f..2404694b 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -943,7 +943,7 @@ def initialize_state( "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) - return concatenate_states(system) # pyright: ignore[reportArgumentType] + return concatenate_states(system) converters = [ ("pymatgen.core", "Structure", ts.io.structures_to_state), @@ -960,7 +960,7 @@ def initialize_state( if isinstance(system, cls) or ( isinstance(system, list) and all(isinstance(s, cls) for s in system) ): - return converter_func(system, device, dtype) # pyright: ignore[reportArgumentType] + return converter_func(system, device, dtype) except ImportError: continue diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 26cbf38f..de6b0566 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -1,7 +1,7 @@ """Types used across torch-sim.""" from enum import Enum -from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar import torch @@ -44,9 +44,9 @@ class BravaisType(Enum): Atoms | Structure | PhonopyAtoms - | list["Atoms"] - | list["Structure"] - | list["PhonopyAtoms"] + | list[Atoms] + | list[Structure] + | list[PhonopyAtoms] | SimState | list[SimState] # | SimStateVar From 55a733330ec492dd5f9cb382bf5212b62d9578c5 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 11:05:18 -0700 Subject: [PATCH 5/8] added overrides for get_attrs_for_scope --- torch_sim/state.py | 62 +++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index 2404694b..a2db20ea 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -11,7 +11,7 @@ from collections import defaultdict from collections.abc import Generator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload import torch @@ -609,27 +609,37 @@ def state_to_device( return type(state)(**attrs) -def get_attrs_for_per_atom_scope( - state: SimState, -) -> Generator[tuple[str, torch.Tensor], None, None]: - """Get attributes for the per-atom scope.""" - for attr_name in state._atom_attributes: # noqa: SLF001 - yield attr_name, getattr(state, attr_name) +@overload +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system"] +) -> Generator[tuple[str, torch.Tensor], None, None]: ... -def get_attrs_for_per_system_scope( - state: SimState, -) -> Generator[tuple[str, torch.Tensor], None, None]: - """Get attributes for the per-system scope.""" - for attr_name in state._system_attributes: # noqa: SLF001 - yield attr_name, getattr(state, attr_name) +@overload +def get_attrs_for_scope( + state: SimState, scope: Literal["global"] +) -> Generator[tuple[str, Any], None, None]: ... -def get_attrs_for_global_scope( - state: SimState, +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system", "global"] ) -> Generator[tuple[str, Any], None, None]: - """Get attributes for the global scope.""" - for attr_name in state._global_attributes: # noqa: SLF001 + """Get attributes for a given scope. + + Args: + state (SimState): The state to get attributes for + scope (Literal["per-atom", "per-system", "global"]): The scope to get + attributes for + + Returns: + Generator[tuple[str, Any], None, None]: A generator of attribute names and values + """ + attr_names = { + "per-atom": state._atom_attributes, # noqa: SLF001 + "per-system": state._system_attributes, # noqa: SLF001 + "global": state._global_attributes, # noqa: SLF001 + }[scope] + for attr_name in attr_names: yield attr_name, getattr(state, attr_name) @@ -653,10 +663,10 @@ def _filter_attrs_by_mask( dict: Filtered attributes with appropriate handling for each scope """ # Copy global attributes directly - filtered_attrs = dict(get_attrs_for_global_scope(state)) + filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Filter per-atom attributes - for attr_name, attr_value in get_attrs_for_per_atom_scope(state): + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] @@ -680,7 +690,7 @@ def _filter_attrs_by_mask( filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes - for attr_name, attr_value in get_attrs_for_per_system_scope(state): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs @@ -704,15 +714,15 @@ def _split_state( system_sizes: list[int] = torch.bincount(state.system_idx).tolist() split_per_atom: dict[str, Sequence[torch.Tensor]] = {} - for attr_name, attr_value in get_attrs_for_per_atom_scope(state): + for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name != "system_idx": split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system: dict[str, Sequence[torch.Tensor]] = {} - for attr_name, attr_value in get_attrs_for_per_system_scope(state): + for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) - global_attrs = dict(get_attrs_for_global_scope(state)) + global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system states: list[SimStateVar] = [] @@ -862,7 +872,7 @@ def concatenate_states( target_device = device or first_state.device # Initialize result with global properties from first state - concatenated = dict(get_attrs_for_global_scope(first_state)) + concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict[str, list[torch.Tensor]](list) @@ -877,14 +887,14 @@ def concatenate_states( state = state_to_device(state, target_device) # Collect per-atom properties - for prop, val in get_attrs_for_per_atom_scope(state): + for prop, val in get_attrs_for_scope(state, "per-atom"): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties - for prop, val in get_attrs_for_per_system_scope(state): + for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) # Update system indices From 835ad759d7b821c8ff59d6778e43be19f91283a3 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 11:08:55 -0700 Subject: [PATCH 6/8] basedpyright passes most type checks --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 08764afa..697236ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,6 +163,6 @@ exclude = [".venv"] reportUnknownMemberType = "none" reportUnknownVariableType = "none" reportUnknownArgumentType = "none" -# reportMissingTypeStubs = "none" +reportMissingTypeStubs = "none" # reportPrivateLocalImportUsage = "none" # reportUnknownParameterType = "none" \ No newline at end of file From 90b478867a2300895a1833cbcebd09699bf7764e Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 11:14:07 -0700 Subject: [PATCH 7/8] ignore pyright errors --- pyproject.toml | 4 +--- torch_sim/typing.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 697236ec..1f6a69cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,6 +163,4 @@ exclude = [".venv"] reportUnknownMemberType = "none" reportUnknownVariableType = "none" reportUnknownArgumentType = "none" -reportMissingTypeStubs = "none" -# reportPrivateLocalImportUsage = "none" -# reportUnknownParameterType = "none" \ No newline at end of file +reportMissingTypeStubs = "none" \ No newline at end of file diff --git a/torch_sim/typing.py b/torch_sim/typing.py index de6b0566..ea603dea 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -49,6 +49,4 @@ class BravaisType(Enum): | list[PhonopyAtoms] | SimState | list[SimState] - # | SimStateVar - # | list[SimStateVar] ) From 2c48c3fe4f6ed7b1b3888367da2d616de05ed805 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Sat, 30 Aug 2025 11:44:22 -0700 Subject: [PATCH 8/8] revert .item call --- torch_sim/state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index a2db20ea..8fd59aff 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -672,9 +672,9 @@ def _filter_attrs_by_mask( old_system_idxs = attr_value[atom_mask] # Get the system indices that are kept - kept_indices = torch.arange( - (attr_value.max() + 1).item(), device=attr_value.device - )[system_mask] + kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ + system_mask + ] # Create a mapping from old system indices to new consecutive indices system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)}