Skip to content

Commit 5e181ed

Browse files
committed
rm classvar annotation from all simstate
1 parent 9cbcd87 commit 5e181ed

File tree

10 files changed

+34
-41
lines changed

10 files changed

+34
-41
lines changed

examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
# ///
99

1010
from dataclasses import dataclass
11-
from typing import ClassVar
1211

1312
import torch
1413
from mace.calculators.foundations_models import mace_mp
@@ -77,7 +76,7 @@ class HybridSwapMCState(MDState):
7776
"""
7877

7978
last_permutation: torch.Tensor
80-
_atom_attributes: ClassVar[set[str]] = (
79+
_atom_attributes = (
8180
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
8281
)
8382

examples/tutorials/hybrid_swap_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class HybridSwapMCState(ts.integrators.MDState):
106106
"""
107107

108108
last_permutation: torch.Tensor
109-
_atom_attributes: ClassVar[set[str]] = (
109+
_atom_attributes = (
110110
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
111111
)
112112

tests/test_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class ChildState(SimState):
6666
duplicated_attribute: bool
6767

6868
_system_attributes: typing.ClassVar[set[str]] = (
69-
SimState._atom_attributes | {"duplicated_attribute"} # noqa: SLF001
69+
SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001
7070
)
7171
_global_attributes: typing.ClassVar[set[str]] = (
7272
SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001

torch_sim/integrators/md.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import ClassVar
65

76
import torch
87

@@ -42,10 +41,10 @@ class MDState(SimState):
4241
energy: torch.Tensor
4342
forces: torch.Tensor
4443

45-
_atom_attributes: ClassVar[set[str]] = (
44+
_atom_attributes = (
4645
SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001
4746
)
48-
_system_attributes: ClassVar[set[str]] = (
47+
_system_attributes = (
4948
SimState._system_attributes | {"energy"} # noqa: SLF001
5049
)
5150

torch_sim/integrators/npt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any, ClassVar
5+
from typing import Any
66

77
import torch
88

@@ -67,10 +67,10 @@ class NPTLangevinState(SimState):
6767
cell_velocities: torch.Tensor
6868
cell_masses: torch.Tensor
6969

70-
_atom_attributes: ClassVar[set[str]] = (
70+
_atom_attributes = (
7171
SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001
7272
)
73-
_system_attributes: ClassVar[set[str]] = SimState._system_attributes | { # noqa: SLF001
73+
_system_attributes = SimState._system_attributes | { # noqa: SLF001
7474
"stress",
7575
"cell_positions",
7676
"cell_velocities",
@@ -879,7 +879,7 @@ class NPTNoseHooverState(MDState):
879879
barostat: NoseHooverChain
880880
barostat_fns: NoseHooverChainFns
881881

882-
_system_attributes: ClassVar[set[str]] = (
882+
_system_attributes = (
883883
MDState._system_attributes # noqa: SLF001
884884
| {
885885
"reference_cell",
@@ -888,7 +888,7 @@ class NPTNoseHooverState(MDState):
888888
"cell_mass",
889889
}
890890
)
891-
_global_attributes: ClassVar[set[str]] = (
891+
_global_attributes = (
892892
MDState._global_attributes # noqa: SLF001
893893
| {
894894
"thermostat",

torch_sim/integrators/nvt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any, ClassVar
5+
from typing import Any
66

77
import torch
88

@@ -266,7 +266,7 @@ class NVTNoseHooverState(MDState):
266266
chain: NoseHooverChain
267267
_chain_fns: NoseHooverChainFns
268268

269-
_global_attributes: ClassVar[set[str]] = (
269+
_global_attributes = (
270270
MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001
271271
)
272272

torch_sim/monte_carlo.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from collections.abc import Callable
1414
from dataclasses import dataclass
15-
from typing import ClassVar
1615

1716
import torch
1817

@@ -37,12 +36,8 @@ class SwapMCState(SimState):
3736
energy: torch.Tensor
3837
last_permutation: torch.Tensor
3938

40-
_atom_attributes: ClassVar[set[str]] = (
41-
SimState._atom_attributes | {"last_permutation"} # noqa: SLF001
42-
)
43-
_system_attributes: ClassVar[set[str]] = (
44-
SimState._system_attributes | {"energy"} # noqa: SLF001
45-
)
39+
_atom_attributes = SimState._atom_attributes | {"last_permutation"} # noqa: SLF001
40+
_system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001
4641

4742

4843
def generate_swaps(

torch_sim/optimizers.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
from collections.abc import Callable
2222
from dataclasses import dataclass
23-
from typing import Any, ClassVar, Literal, get_args
23+
from typing import Any, Literal, get_args
2424

2525
import torch
2626

@@ -221,7 +221,7 @@ class UnitCellGDState(GDState, DeformGradMixin):
221221
cell_forces: torch.Tensor
222222
cell_masses: torch.Tensor
223223

224-
_system_attributes: ClassVar[set[str]] = (
224+
_system_attributes = (
225225
GDState._system_attributes # noqa: SLF001
226226
| DeformGradMixin._system_attributes # noqa: SLF001
227227
| {
@@ -233,7 +233,7 @@ class UnitCellGDState(GDState, DeformGradMixin):
233233
"cell_masses",
234234
}
235235
)
236-
_global_attributes: ClassVar[set[str]] = (
236+
_global_attributes = (
237237
GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001
238238
)
239239

@@ -524,8 +524,8 @@ class FireState(SimState):
524524
alpha: torch.Tensor
525525
n_pos: torch.Tensor
526526

527-
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
528-
_system_attributes: ClassVar[set[str]] = (
527+
_atom_attributes = md_atom_attributes
528+
_system_attributes = (
529529
SimState._system_attributes # noqa: SLF001
530530
| {
531531
"energy",
@@ -745,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin):
745745
alpha: torch.Tensor
746746
n_pos: torch.Tensor
747747

748-
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
749-
_system_attributes: ClassVar[set[str]] = _fire_system_attributes
750-
_global_attributes: ClassVar[set[str]] = _fire_global_attributes
748+
_atom_attributes = md_atom_attributes
749+
_system_attributes = _fire_system_attributes
750+
_global_attributes = _fire_global_attributes
751751

752752

753753
def unit_cell_fire(
@@ -1037,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
10371037
alpha: torch.Tensor
10381038
n_pos: torch.Tensor
10391039

1040-
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
1041-
_system_attributes: ClassVar[set[str]] = _fire_system_attributes
1042-
_global_attributes: ClassVar[set[str]] = _fire_global_attributes
1040+
_atom_attributes = md_atom_attributes
1041+
_system_attributes = _fire_system_attributes
1042+
_global_attributes = _fire_global_attributes
10431043

10441044

10451045
def frechet_cell_fire(

torch_sim/runners.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from collections.abc import Callable
1010
from dataclasses import dataclass
1111
from itertools import chain
12-
from typing import Any, ClassVar
12+
from typing import Any
1313

1414
import torch
1515
from tqdm import tqdm
@@ -541,10 +541,10 @@ class StaticState(type(state)):
541541
forces: torch.Tensor
542542
stress: torch.Tensor
543543

544-
_atom_attributes: ClassVar[set[str]] = (
544+
_atom_attributes = (
545545
state._atom_attributes | {"forces"} # noqa: SLF001
546546
)
547-
_system_attributes: ClassVar[set[str]] = (
547+
_system_attributes = (
548548
state._system_attributes | {"energy", "stress"} # noqa: SLF001
549549
)
550550

torch_sim/state.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,13 @@ def _assert_all_attributes_have_defined_scope(cls) -> None:
477477
vars(cls).get(attr_name), property
478478
)
479479
is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name))
480-
is_scope_list = attr_name in [
481-
"_atom_attributes",
482-
"_system_attributes",
483-
"_global_attributes",
484-
]
480+
is_class_variable = (
481+
# Note: _atom_attributes, _system_attributes, and _global_attributes
482+
# are all class variables
483+
typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar
484+
)
485485

486-
if is_special_attribute or is_property or is_method or is_scope_list:
486+
if is_special_attribute or is_property or is_method or is_class_variable:
487487
continue
488488

489489
if attr_name not in all_defined_attributes:

0 commit comments

Comments
 (0)