-
Notifications
You must be signed in to change notification settings - Fork 39
Basic Pyright checks for state.py #241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughAdds a BasedPyright config to pyproject.toml and refactors typing and control flow in Changes
Sequence Diagram(s)sequenceDiagram
autonumber
actor Caller
participant State as SimState
participant ScopeMap as Scope Dispatch
Caller->>State: get_attrs_for_scope(scope)
State->>ScopeMap: dispatch[scope]
alt scope in {"per-atom","per-system"}
ScopeMap-->>Caller: Generator[(str, Tensor)]
else scope == "global"
ScopeMap-->>Caller: Generator[(str, Any)]
else unknown scope
ScopeMap-->>Caller: KeyError
end
sequenceDiagram
autonumber
participant Orchestrator
participant Split as _split_state
participant Concat as concatenate_states
Orchestrator->>Split: split multi-system state
Split->>Split: compute system_sizes
Split->>Split: slice per-atom/system tensors
Split->>Split: init system_idx = zeros(len(system))
Split-->>Orchestrator: list[SimState]
Orchestrator->>Concat: concatenate list[SimState]
Concat->>Concat: stack per-atom/system tensors
Concat-->>Orchestrator: merged SimState
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
CodeRabbit Configuration File (
|
@@ -606,14 +602,25 @@ def state_to_device( | |||
if isinstance(attr_value, torch.Tensor): | |||
attrs[attr_name] = attr_value.to(device=device) | |||
|
|||
if dtype is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we assign dtype, so this can never be None
@@ -194,11 +194,7 @@ def n_atoms(self) -> int: | |||
@property | |||
def n_atoms_per_system(self) -> torch.Tensor: | |||
"""Number of atoms per system.""" | |||
return ( | |||
self.system_idx.bincount() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
system_idx will never be none
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
torch_sim/typing.py (1)
10-14
: Move imports out of TYPE_CHECKING block as they're used at runtime.The linter correctly identifies that these imports are being used in the
StateLike
type alias at runtime (lines 43-52), not just for type hints. They should be moved outside the TYPE_CHECKING block to avoid runtimeNameError
exceptions.-if TYPE_CHECKING: - from ase import Atoms - from phonopy.structure.atoms import PhonopyAtoms - from pymatgen.core import Structure - - from torch_sim.state import SimState +from typing import TYPE_CHECKING + +# These imports are used at runtime in the StateLike type alias +from ase import Atoms +from phonopy.structure.atoms import PhonopyAtoms +from pymatgen.core import Structure +from torch_sim.state import SimStateHowever, if these libraries are optional dependencies, consider using string literals in the type alias instead:
StateLike = ( "Atoms" | "Structure" | "PhonopyAtoms" | list["Atoms"] | list["Structure"] | list["PhonopyAtoms"] | "SimState" | list["SimState"] )
🧹 Nitpick comments (4)
pyproject.toml (1)
158-158
: Consider keepingreportPrivateUsage
enabled for BasedPyright.While the comment states "since ruff will catch this", ruff's private usage checks (N801-N807) primarily focus on naming conventions, not actual access of private members from outside their defining class. BasedPyright's
reportPrivateUsage
provides more comprehensive checking for accessing private attributes/methods, which is an important aspect of encapsulation that ruff doesn't fully cover.-reportPrivateUsage = "none" # since ruff will catch this +# reportPrivateUsage is kept at default to catch private member access violationstorch_sim/state.py (3)
637-641
: Dictionary dispatch is cleaner but raises KeyError instead of ValueError.The change from if/elif to dictionary dispatch improves readability. However, this changes the exception type from
ValueError
toKeyError
for unknown scopes. WhileKeyError
is technically correct for dictionary lookups,ValueError
might be more appropriate for invalid parameter values.Consider wrapping in a try/except to maintain the
ValueError
:- attr_names = { - "per-atom": state._atom_attributes, # noqa: SLF001 - "per-system": state._system_attributes, # noqa: SLF001 - "global": state._global_attributes, # noqa: SLF001 - }[scope] + scope_map = { + "per-atom": state._atom_attributes, # noqa: SLF001 + "per-system": state._system_attributes, # noqa: SLF001 + "global": state._global_attributes, # noqa: SLF001 + } + try: + attr_names = scope_map[scope] + except KeyError: + raise ValueError(f"Unknown scope: {scope}") from None
571-574
: Pyright ignore comments could be avoided with better type guards.The
reportUnnecessaryIsInstance
andreportUnreachable
warnings suggest pyright doesn't understand the type narrowing here. Consider using a type guard or reorganizing the logic.- if isinstance(system_indices, torch.Tensor): # pyright: ignore[reportUnnecessaryIsInstance] - # Handle negative indices in tensors - return torch.where(system_indices < 0, n_systems + system_indices, system_indices) - raise TypeError(f"Unsupported index type: {type(system_indices)}") # pyright: ignore[reportUnreachable] + # At this point, system_indices must be a tensor if we've handled all other cases + assert isinstance(system_indices, torch.Tensor), f"Unsupported index type: {type(system_indices)}" + # Handle negative indices in tensors + return torch.where(system_indices < 0, n_systems + system_indices, system_indices)
950-950
: Consider addressing the pyright ignore more elegantly.The
reportAttributeAccessIssue
ignore suggests pyright doesn't understand that all items in the list are SimState objects. This was already checked on line 949.Consider using a type guard or assertion to help the type checker:
- if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_systems == 1 for state in system): # pyright: ignore[reportAttributeAccessIssue] + if isinstance(system, list) and all(isinstance(s, SimState) for s in system): + # Type narrowing for pyright + sim_states = cast(list[SimState], system) + if not all(state.n_systems == 1 for state in sim_states):
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
pyproject.toml
(1 hunks)torch_sim/state.py
(13 hunks)torch_sim/typing.py
(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
torch_sim/typing.py (2)
torch_sim/state.py (1)
SimState
(29-489)torch_sim/io.py (3)
phonopy_to_state
(316-391)atoms_to_state
(180-245)state_to_phonopy
(131-177)
torch_sim/state.py (2)
torch_sim/io.py (1)
atoms_to_state
(180-245)torch_sim/optimizers.py (1)
FireState
(484-536)
🪛 GitHub Actions: Linting
pyproject.toml
[error] 1-1: End-of-file-fixer: Fixed missing trailing newline in pyproject.toml; re-run pre-commit to verify.
torch_sim/typing.py
[error] 10-10: TC004 Move import ase.Atoms
out of type-checking block. Import is used for more than type hinting.
[error] 11-11: TC004 Move import phonopy.structure.atoms.PhonopyAtoms
out of type-checking block. Import is used for more than type hinting.
[error] 12-12: TC004 Move import pymatgen.core.Structure
out of type-checking block. Import is used for more than type hinting.
[error] 14-14: TC004 Move import torch_sim.state.SimState
out of type-checking block. Import is used for more than type hinting.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (17)
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-model (macos-14, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mattersim, tests/models/test_mattersim.py)
- GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/typing.py (1)
43-52
: Good modernization of the type alias syntax.The change from
Union[...]
to the modern|
syntax (PEP 604) improves readability. The switch fromSimStateVar
to concreteSimState
types also makes the type more explicit.torch_sim/state.py (7)
411-411
: Good addition of type annotation for kwargs.Adding
Any
type annotation to**kwargs
improves type clarity for the__init_subclass__
method.
612-622
: Excellent use of overloads for type safety.The overloaded signatures for
get_attrs_for_scope
correctly express that:
- "per-atom" and "per-system" scopes always return
torch.Tensor
values- "global" scope can return
Any
type valuesThis provides better type information to callers and helps catch type errors at static analysis time.
197-197
: Simplified implementation but changes behavior for edge cases.The removal of the None check simplifies the code. However, this assumes
system_idx
is never None at this point, which should be guaranteed by the constructor's initialization logic.
605-608
: Simplified dtype handling looks correct.The unconditional dtype casting is cleaner since dtype is guaranteed to be non-None at this point (line 598 ensures a default).
714-746
: Good type annotations and proper system_idx initialization.The explicit type annotations for
system_sizes
,split_per_atom
,split_per_system
, andstates
improve type clarity. The change to initializesystem_idx
with zeros for each split system is correct - each resulting single-system state should have all atoms belonging to system 0.
878-880
: Proper use of defaultdict with type parameters.The type annotations for
defaultdict[str, list[torch.Tensor]]
correctly specify the expected types for the dictionary values.
430-430
: Good handling of Python 3.10+ union syntax.The code correctly handles both
typing.Union
and Python 3.10+|
syntax (types.UnionType) for union types, ensuring compatibility across Python versions.
[tool.basedpyright] | ||
reportImplicitStringConcatenation = "none" | ||
reportPrivateUsage = "none" # since ruff will catch this | ||
reportAny= "none" | ||
reportExplicitAny= "none" | ||
include = ["torch_sim", "tests", "examples", "docs"] | ||
exclude = [".venv"] | ||
reportUnknownMemberType = "none" | ||
reportUnknownVariableType = "none" | ||
reportUnknownArgumentType = "none" | ||
reportMissingTypeStubs = "none" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix the trailing newline issue flagged by the linter.
The pipeline failure indicates a missing trailing newline at the end of the file.
reportUnknownVariableType = "none"
reportUnknownArgumentType = "none"
reportMissingTypeStubs = "none"
+
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
[tool.basedpyright] | |
reportImplicitStringConcatenation = "none" | |
reportPrivateUsage = "none" # since ruff will catch this | |
reportAny= "none" | |
reportExplicitAny= "none" | |
include = ["torch_sim", "tests", "examples", "docs"] | |
exclude = [".venv"] | |
reportUnknownMemberType = "none" | |
reportUnknownVariableType = "none" | |
reportUnknownArgumentType = "none" | |
reportMissingTypeStubs = "none" | |
[tool.basedpyright] | |
reportImplicitStringConcatenation = "none" | |
reportPrivateUsage = "none" # since ruff will catch this | |
reportAny= "none" | |
reportExplicitAny= "none" | |
include = ["torch_sim", "tests", "examples", "docs"] | |
exclude = [".venv"] | |
reportUnknownMemberType = "none" | |
reportUnknownVariableType = "none" | |
reportUnknownArgumentType = "none" | |
reportMissingTypeStubs = "none" | |
🤖 Prompt for AI Agents
In pyproject.toml around lines 156 to 166, the file is missing a trailing
newline at EOF causing the linter/pipeline failure; open the file and add a
single POSIX newline character at the end of the file (ensure the final line
ends with '\n'), save and commit so the file ends with a newline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
torch_sim/state.py (2)
600-609
: Bug: state_to_device mutates the input state in-place.
attrs = vars(state) returns the live dict; assigning back to attrs[...] mutates state before constructing the new instance. to() should not have side effects on the original object.Apply this fix to operate on a new dict:
- attrs = vars(state) - for attr_name, attr_value in attrs.items(): - if isinstance(attr_value, torch.Tensor): - attrs[attr_name] = attr_value.to(device=device) - - attrs["positions"] = attrs["positions"].to(dtype=dtype) - attrs["masses"] = attrs["masses"].to(dtype=dtype) - attrs["cell"] = attrs["cell"].to(dtype=dtype) - attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) - return type(state)(**attrs) + new_attrs: dict[str, Any] = {} + for attr_name, attr_value in vars(state).items(): + if isinstance(attr_value, torch.Tensor): + new_attrs[attr_name] = attr_value.to(device=device) + else: + new_attrs[attr_name] = attr_value + + new_attrs["positions"] = new_attrs["positions"].to(dtype=dtype) + new_attrs["masses"] = new_attrs["masses"].to(dtype=dtype) + new_attrs["cell"] = new_attrs["cell"].to(dtype=dtype) + new_attrs["atomic_numbers"] = new_attrs["atomic_numbers"].to(dtype=torch.int) + return type(state)(**new_attrs)Optional follow-up: also cast other float per-atom/system tensors to dtype to avoid mixed dtypes (e.g., forces/energy) by checking tensor.is_floating_point() and excluding atomic_numbers/system_idx.
949-957
: initialize_state ignores device/dtype when given list[SimState].
Currently returns concatenate_states(system) without moving/casting, so the device/dtype arguments are silently dropped; concatenate_states also defaults device to first state's device.Apply this to normalize each state and honor device:
- if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(state.n_systems == 1 for state in system): # pyright: ignore[reportAttributeAccessIssue] + if isinstance(system, list) and all(isinstance(s, SimState) for s in system): + if not all(s.n_systems == 1 for s in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) - return concatenate_states(system) + normed: list[SimState] = [state_to_device(s, device, dtype) for s in system] + return concatenate_states(normed, device=device)Additionally, consider validating that all global attributes (e.g., pbc) agree before concatenation to avoid silent mismatches.
♻️ Duplicate comments (1)
torch_sim/state.py (1)
195-198
: n_atoms_per_system simplification is correct.
Direct bincount on system_idx is fine given constructor guarantees contiguity. Matches the earlier note that system_idx is never None.
🧹 Nitpick comments (1)
torch_sim/state.py (1)
430-435
: Optional: simplify Union detection to drop deprecated-API ignore.
You can avoid pyright: reportDeprecated by checking against types.UnionType in one shot.Apply within this block:
- is_union = origin is typing.Union # pyright: ignore[reportDeprecated] - if not is_union and origin is not None: - # For Python 3.10+ `|` syntax, origin is types.UnionType - # We check by name to be robust against module reloading/patching issues - is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" + import types as _types + is_union = origin in (getattr(typing, "Union", None), _types.UnionType)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
torch_sim/state.py
(12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
torch_sim/state.py (3)
torch_sim/io.py (2)
atoms_to_state
(180-245)phonopy_to_state
(316-391)tests/test_state.py (1)
ChildState
(47-53)torch_sim/runners.py (1)
StaticState
(539-549)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (35)
- GitHub Check: test-examples (examples/tutorials/using_graphpes_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: build-docs
🔇 Additional comments (8)
torch_sim/state.py (8)
12-15
: PEP 585 imports and overload usage look good.
Moving Sequence to collections.abc and adding overload is appropriate.
411-422
: init_subclass kwarg typing: LGTM.
Explicit Any on kwargs improves clarity with pyright.
571-575
: Index normalization branch is fine.
The explicit torch.Tensor branch plus final TypeError keeps behavior clear; the pyright ignores are acceptable.
612-622
: Overloads for get_attrs_for_scope: nice typing improvement.
This gives callers precise yields per scope.
637-641
: KeyError vs ValueError behavior change—confirm API expectations.
Indexing the map with [scope] raises KeyError on bad input; prior behavior raised ValueError. If external callers rely on ValueError, consider preserving it.Option to keep ValueError:
- attr_names = { + mapping = { "per-atom": state._atom_attributes, # noqa: SLF001 "per-system": state._system_attributes, # noqa: SLF001 "global": state._global_attributes, # noqa: SLF001 - }[scope] + } + try: + attr_names = mapping[scope] + except KeyError as e: + raise ValueError(f"Unknown scope: {scope!r}") from e
650-650
: Return type on _filter_attrs_by_mask clarified: good.
dict[str, Any] matches the mixed global/per-atom/system outputs.
714-715
: Split: behavior change to reset system_idx to zeros—verify invariants.
The new logic bins by system_idx, splits tensors, and assigns system_idx=torch.zeros(...) per substate. This is sensible; please confirm downstream code never expects preserved original system numbering within single-system states.Also applies to: 716-724, 728-729, 746-746
878-881
: Typed defaultdicts and explicit new_system_indices list: LGTM.
Clearer types; concatenation logic below remains efficient.
closing because I feel like mypy has much better pytorch support. tinygrad AND pytorch use mypy. I feel like pyright is better for more generic python software but not for this. |
Summary
torch.Tensor
(and it'll never beNone
). I'll make a few PRs but this is just the first. I tried to get state.py a bit typed just to start a conversation.overload
s forget_attrs_for_scope
so when ppl are using them, they will know the types that are returned. This is NOT a design decision because already; all atom and system attributes must betorch.Tensor
whereas all global attributes can be Any. This must be enforced because when we split sim_states, we usetorch.split
on those per-atom and per-system attributes.Checklist
Before a pull request can be merged, the following items must be checked:
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit install
to install the hooks which will check your code before each commit.Summary by CodeRabbit