Skip to content

Conversation

curtischong
Copy link
Collaborator

@curtischong curtischong commented Aug 30, 2025

Summary

  • Adds pyright with a bunch of rules checked off. The goal is to solve the actually important type issues (e.g. assuming that a variable is always torch.Tensor (and it'll never be None). I'll make a few PRs but this is just the first. I tried to get state.py a bit typed just to start a conversation.
  • I added overloads for get_attrs_for_scope so when ppl are using them, they will know the types that are returned. This is NOT a design decision because already; all atom and system attributes must be torch.Tensor whereas all global attributes can be Any. This must be enforced because when we split sim_states, we use torch.split on those per-atom and per-system attributes.

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run pip install pre-commit && pre-commit install to install the hooks which will check your code before each commit.

Summary by CodeRabbit

  • Refactor
    • Improved and modernized typing across simulation state and utilities, including clearer overloads for attribute retrieval by scope.
    • Streamlined device/dtype handling and more consistent system indexing when splitting/merging states.
  • Bug Fixes
    • More explicit error behavior for invalid attribute-scope requests.
  • Chores
    • Updated type-checker configuration to reduce noise and target relevant project paths.
    • Simplified public type aliases for modern Python typing compatibility.

@cla-bot cla-bot bot added the cla-signed Contributor license agreement signed label Aug 30, 2025
Copy link

coderabbitai bot commented Aug 30, 2025

Walkthrough

Adds a BasedPyright config to pyproject.toml and refactors typing and control flow in torch_sim: adds overloads for get_attrs_for_scope, tightens annotations, changes scope dispatch to dict (KeyError on unknown), simplifies dtype handling, and updates StateLike to PEP 604 union syntax.

Changes

Cohort / File(s) Summary of changes
Type-checker config
pyproject.toml
Adds a new [tool.basedpyright] block disabling many report* checks, restricts include to ["torch_sim","tests","examples","docs"], and excludes [".venv"].
SimState typing and state ops
torch_sim/state.py
Adds Sequence and overload imports; annotates __init_subclass__ kwargs as Any; adds overloads for get_attrs_for_scope; switches scope handling to dict dispatch (raises KeyError for unknown); simplifies n_atoms_per_system and unconditionally applies dtype casting in state_to_device; refines _split_state/concatenate logic with explicit types and zero-initialized per-system system_idx; adds pyright ignores where needed.
Public typing aliases
torch_sim/typing.py
Replaces Union[...]-based StateLike with PEP 604 (`

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant State as SimState
  participant ScopeMap as Scope Dispatch

  Caller->>State: get_attrs_for_scope(scope)
  State->>ScopeMap: dispatch[scope]
  alt scope in {"per-atom","per-system"}
    ScopeMap-->>Caller: Generator[(str, Tensor)]
  else scope == "global"
    ScopeMap-->>Caller: Generator[(str, Any)]
  else unknown scope
    ScopeMap-->>Caller: KeyError
  end
Loading
sequenceDiagram
  autonumber
  participant Orchestrator
  participant Split as _split_state
  participant Concat as concatenate_states

  Orchestrator->>Split: split multi-system state
  Split->>Split: compute system_sizes
  Split->>Split: slice per-atom/system tensors
  Split->>Split: init system_idx = zeros(len(system))
  Split-->>Orchestrator: list[SimState]

  Orchestrator->>Concat: concatenate list[SimState]
  Concat->>Concat: stack per-atom/system tensors
  Concat-->>Orchestrator: merged SimState
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • orionarcher
  • CompRhys

Poem

I thump my paw on typing’s tune,
Dispatch maps hum a cleaner rune.
Systems split, their indices new—
Zeros march in tidy queue.
Pyright naps, reports at bay — I hop and nibble docs away. 🥕🐇

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@@ -606,14 +602,25 @@ def state_to_device(
if isinstance(attr_value, torch.Tensor):
attrs[attr_name] = attr_value.to(device=device)

if dtype is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we assign dtype, so this can never be None

@@ -194,11 +194,7 @@ def n_atoms(self) -> int:
@property
def n_atoms_per_system(self) -> torch.Tensor:
"""Number of atoms per system."""
return (
self.system_idx.bincount()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

system_idx will never be none

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
torch_sim/typing.py (1)

10-14: Move imports out of TYPE_CHECKING block as they're used at runtime.

The linter correctly identifies that these imports are being used in the StateLike type alias at runtime (lines 43-52), not just for type hints. They should be moved outside the TYPE_CHECKING block to avoid runtime NameError exceptions.

-if TYPE_CHECKING:
-    from ase import Atoms
-    from phonopy.structure.atoms import PhonopyAtoms
-    from pymatgen.core import Structure
-
-    from torch_sim.state import SimState
+from typing import TYPE_CHECKING
+
+# These imports are used at runtime in the StateLike type alias
+from ase import Atoms
+from phonopy.structure.atoms import PhonopyAtoms
+from pymatgen.core import Structure
+from torch_sim.state import SimState

However, if these libraries are optional dependencies, consider using string literals in the type alias instead:

StateLike = (
    "Atoms"
    | "Structure"
    | "PhonopyAtoms"
    | list["Atoms"]
    | list["Structure"]
    | list["PhonopyAtoms"]
    | "SimState"
    | list["SimState"]
)
🧹 Nitpick comments (4)
pyproject.toml (1)

158-158: Consider keeping reportPrivateUsage enabled for BasedPyright.

While the comment states "since ruff will catch this", ruff's private usage checks (N801-N807) primarily focus on naming conventions, not actual access of private members from outside their defining class. BasedPyright's reportPrivateUsage provides more comprehensive checking for accessing private attributes/methods, which is an important aspect of encapsulation that ruff doesn't fully cover.

-reportPrivateUsage = "none" # since ruff will catch this
+# reportPrivateUsage is kept at default to catch private member access violations
torch_sim/state.py (3)

637-641: Dictionary dispatch is cleaner but raises KeyError instead of ValueError.

The change from if/elif to dictionary dispatch improves readability. However, this changes the exception type from ValueError to KeyError for unknown scopes. While KeyError is technically correct for dictionary lookups, ValueError might be more appropriate for invalid parameter values.

Consider wrapping in a try/except to maintain the ValueError:

-    attr_names = {
-        "per-atom": state._atom_attributes,  # noqa: SLF001
-        "per-system": state._system_attributes,  # noqa: SLF001
-        "global": state._global_attributes,  # noqa: SLF001
-    }[scope]
+    scope_map = {
+        "per-atom": state._atom_attributes,  # noqa: SLF001
+        "per-system": state._system_attributes,  # noqa: SLF001
+        "global": state._global_attributes,  # noqa: SLF001
+    }
+    try:
+        attr_names = scope_map[scope]
+    except KeyError:
+        raise ValueError(f"Unknown scope: {scope}") from None

571-574: Pyright ignore comments could be avoided with better type guards.

The reportUnnecessaryIsInstance and reportUnreachable warnings suggest pyright doesn't understand the type narrowing here. Consider using a type guard or reorganizing the logic.

-    if isinstance(system_indices, torch.Tensor):  # pyright: ignore[reportUnnecessaryIsInstance]
-        # Handle negative indices in tensors
-        return torch.where(system_indices < 0, n_systems + system_indices, system_indices)
-    raise TypeError(f"Unsupported index type: {type(system_indices)}")  # pyright: ignore[reportUnreachable]
+    # At this point, system_indices must be a tensor if we've handled all other cases
+    assert isinstance(system_indices, torch.Tensor), f"Unsupported index type: {type(system_indices)}"
+    # Handle negative indices in tensors
+    return torch.where(system_indices < 0, n_systems + system_indices, system_indices)

950-950: Consider addressing the pyright ignore more elegantly.

The reportAttributeAccessIssue ignore suggests pyright doesn't understand that all items in the list are SimState objects. This was already checked on line 949.

Consider using a type guard or assertion to help the type checker:

-    if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
-        if not all(state.n_systems == 1 for state in system):  # pyright: ignore[reportAttributeAccessIssue]
+    if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
+        # Type narrowing for pyright
+        sim_states = cast(list[SimState], system)
+        if not all(state.n_systems == 1 for state in sim_states):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between b251e4c and 90b4788.

📒 Files selected for processing (3)
  • pyproject.toml (1 hunks)
  • torch_sim/state.py (13 hunks)
  • torch_sim/typing.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
torch_sim/typing.py (2)
torch_sim/state.py (1)
  • SimState (29-489)
torch_sim/io.py (3)
  • phonopy_to_state (316-391)
  • atoms_to_state (180-245)
  • state_to_phonopy (131-177)
torch_sim/state.py (2)
torch_sim/io.py (1)
  • atoms_to_state (180-245)
torch_sim/optimizers.py (1)
  • FireState (484-536)
🪛 GitHub Actions: Linting
pyproject.toml

[error] 1-1: End-of-file-fixer: Fixed missing trailing newline in pyproject.toml; re-run pre-commit to verify.

torch_sim/typing.py

[error] 10-10: TC004 Move import ase.Atoms out of type-checking block. Import is used for more than type hinting.


[error] 11-11: TC004 Move import phonopy.structure.atoms.PhonopyAtoms out of type-checking block. Import is used for more than type hinting.


[error] 12-12: TC004 Move import pymatgen.core.Structure out of type-checking block. Import is used for more than type hinting.


[error] 14-14: TC004 Move import torch_sim.state.SimState out of type-checking block. Import is used for more than type hinting.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (17)
  • GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, mattersim, tests/models/test_mattersim.py)
  • GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/typing.py (1)

43-52: Good modernization of the type alias syntax.

The change from Union[...] to the modern | syntax (PEP 604) improves readability. The switch from SimStateVar to concrete SimState types also makes the type more explicit.

torch_sim/state.py (7)

411-411: Good addition of type annotation for kwargs.

Adding Any type annotation to **kwargs improves type clarity for the __init_subclass__ method.


612-622: Excellent use of overloads for type safety.

The overloaded signatures for get_attrs_for_scope correctly express that:

  • "per-atom" and "per-system" scopes always return torch.Tensor values
  • "global" scope can return Any type values

This provides better type information to callers and helps catch type errors at static analysis time.


197-197: Simplified implementation but changes behavior for edge cases.

The removal of the None check simplifies the code. However, this assumes system_idx is never None at this point, which should be guaranteed by the constructor's initialization logic.


605-608: Simplified dtype handling looks correct.

The unconditional dtype casting is cleaner since dtype is guaranteed to be non-None at this point (line 598 ensures a default).


714-746: Good type annotations and proper system_idx initialization.

The explicit type annotations for system_sizes, split_per_atom, split_per_system, and states improve type clarity. The change to initialize system_idx with zeros for each split system is correct - each resulting single-system state should have all atoms belonging to system 0.


878-880: Proper use of defaultdict with type parameters.

The type annotations for defaultdict[str, list[torch.Tensor]] correctly specify the expected types for the dictionary values.


430-430: Good handling of Python 3.10+ union syntax.

The code correctly handles both typing.Union and Python 3.10+ | syntax (types.UnionType) for union types, ensuring compatibility across Python versions.

Comment on lines +156 to +166
[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix the trailing newline issue flagged by the linter.

The pipeline failure indicates a missing trailing newline at the end of the file.

 reportUnknownVariableType = "none"
 reportUnknownArgumentType = "none"
 reportMissingTypeStubs = "none"
+
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none" # since ruff will catch this
reportAny= "none"
reportExplicitAny= "none"
include = ["torch_sim", "tests", "examples", "docs"]
exclude = [".venv"]
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
🤖 Prompt for AI Agents
In pyproject.toml around lines 156 to 166, the file is missing a trailing
newline at EOF causing the linter/pipeline failure; open the file and add a
single POSIX newline character at the end of the file (ensure the final line
ends with '\n'), save and commit so the file ends with a newline.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
torch_sim/state.py (2)

600-609: Bug: state_to_device mutates the input state in-place.
attrs = vars(state) returns the live dict; assigning back to attrs[...] mutates state before constructing the new instance. to() should not have side effects on the original object.

Apply this fix to operate on a new dict:

-    attrs = vars(state)
-    for attr_name, attr_value in attrs.items():
-        if isinstance(attr_value, torch.Tensor):
-            attrs[attr_name] = attr_value.to(device=device)
-
-    attrs["positions"] = attrs["positions"].to(dtype=dtype)
-    attrs["masses"] = attrs["masses"].to(dtype=dtype)
-    attrs["cell"] = attrs["cell"].to(dtype=dtype)
-    attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int)
-    return type(state)(**attrs)
+    new_attrs: dict[str, Any] = {}
+    for attr_name, attr_value in vars(state).items():
+        if isinstance(attr_value, torch.Tensor):
+            new_attrs[attr_name] = attr_value.to(device=device)
+        else:
+            new_attrs[attr_name] = attr_value
+
+    new_attrs["positions"] = new_attrs["positions"].to(dtype=dtype)
+    new_attrs["masses"] = new_attrs["masses"].to(dtype=dtype)
+    new_attrs["cell"] = new_attrs["cell"].to(dtype=dtype)
+    new_attrs["atomic_numbers"] = new_attrs["atomic_numbers"].to(dtype=torch.int)
+    return type(state)(**new_attrs)

Optional follow-up: also cast other float per-atom/system tensors to dtype to avoid mixed dtypes (e.g., forces/energy) by checking tensor.is_floating_point() and excluding atomic_numbers/system_idx.


949-957: initialize_state ignores device/dtype when given list[SimState].
Currently returns concatenate_states(system) without moving/casting, so the device/dtype arguments are silently dropped; concatenate_states also defaults device to first state's device.

Apply this to normalize each state and honor device:

-    if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
-        if not all(state.n_systems == 1 for state in system):  # pyright: ignore[reportAttributeAccessIssue]
+    if isinstance(system, list) and all(isinstance(s, SimState) for s in system):
+        if not all(s.n_systems == 1 for s in system):
             raise ValueError(
                 "When providing a list of states, to the initialize_state function, "
                 "all states must have n_systems == 1. To fix this, you can split the "
                 "states into individual states with the split_state function."
             )
-        return concatenate_states(system)
+        normed: list[SimState] = [state_to_device(s, device, dtype) for s in system]
+        return concatenate_states(normed, device=device)

Additionally, consider validating that all global attributes (e.g., pbc) agree before concatenation to avoid silent mismatches.

♻️ Duplicate comments (1)
torch_sim/state.py (1)

195-198: n_atoms_per_system simplification is correct.
Direct bincount on system_idx is fine given constructor guarantees contiguity. Matches the earlier note that system_idx is never None.

🧹 Nitpick comments (1)
torch_sim/state.py (1)

430-435: Optional: simplify Union detection to drop deprecated-API ignore.
You can avoid pyright: reportDeprecated by checking against types.UnionType in one shot.

Apply within this block:

-            is_union = origin is typing.Union  # pyright: ignore[reportDeprecated]
-            if not is_union and origin is not None:
-                # For Python 3.10+ `|` syntax, origin is types.UnionType
-                # We check by name to be robust against module reloading/patching issues
-                is_union = origin.__module__ == "types" and origin.__name__ == "UnionType"
+            import types as _types
+            is_union = origin in (getattr(typing, "Union", None), _types.UnionType)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 90b4788 and 2c48c3f.

📒 Files selected for processing (1)
  • torch_sim/state.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/state.py (3)
torch_sim/io.py (2)
  • atoms_to_state (180-245)
  • phonopy_to_state (316-391)
tests/test_state.py (1)
  • ChildState (47-53)
torch_sim/runners.py (1)
  • StaticState (539-549)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (35)
  • GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
  • GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
  • GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
  • GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
  • GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
  • GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
  • GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
  • GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
  • GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
  • GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
  • GitHub Check: test-model (macos-14, 3.11, highest, graphpes, tests/models/test_graphpes.py)
  • GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/state.py (8)

12-15: PEP 585 imports and overload usage look good.
Moving Sequence to collections.abc and adding overload is appropriate.


411-422: init_subclass kwarg typing: LGTM.
Explicit Any on kwargs improves clarity with pyright.


571-575: Index normalization branch is fine.
The explicit torch.Tensor branch plus final TypeError keeps behavior clear; the pyright ignores are acceptable.


612-622: Overloads for get_attrs_for_scope: nice typing improvement.
This gives callers precise yields per scope.


637-641: KeyError vs ValueError behavior change—confirm API expectations.
Indexing the map with [scope] raises KeyError on bad input; prior behavior raised ValueError. If external callers rely on ValueError, consider preserving it.

Option to keep ValueError:

-    attr_names = {
+    mapping = {
         "per-atom": state._atom_attributes,  # noqa: SLF001
         "per-system": state._system_attributes,  # noqa: SLF001
         "global": state._global_attributes,  # noqa: SLF001
-    }[scope]
+    }
+    try:
+        attr_names = mapping[scope]
+    except KeyError as e:
+        raise ValueError(f"Unknown scope: {scope!r}") from e

650-650: Return type on _filter_attrs_by_mask clarified: good.
dict[str, Any] matches the mixed global/per-atom/system outputs.


714-715: Split: behavior change to reset system_idx to zeros—verify invariants.
The new logic bins by system_idx, splits tensors, and assigns system_idx=torch.zeros(...) per substate. This is sensible; please confirm downstream code never expects preserved original system numbering within single-system states.

Also applies to: 716-724, 728-729, 746-746


878-881: Typed defaultdicts and explicit new_system_indices list: LGTM.
Clearer types; concatenation logic below remains efficient.

@curtischong
Copy link
Collaborator Author

closing because I feel like mypy has much better pytorch support. tinygrad AND pytorch use mypy. I feel like pyright is better for more generic python software but not for this.

@curtischong curtischong deleted the more-types2 branch September 2, 2025 21:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla-signed Contributor license agreement signed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant