Skip to content

Commit efa7e00

Browse files
committed
swap from tuple to set
1 parent 3d89fb3 commit efa7e00

File tree

8 files changed

+118
-89
lines changed

8 files changed

+118
-89
lines changed

examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# ///
99

1010
from dataclasses import dataclass
11+
from typing import ClassVar
1112

1213
import torch
1314
from mace.calculators.foundations_models import mace_mp
@@ -76,7 +77,9 @@ class HybridSwapMCState(MDState):
7677
"""
7778

7879
last_permutation: torch.Tensor
79-
_atom_attributes = (*MDState._atom_attributes, "last_permutation") # noqa: SLF001
80+
_atom_attributes: ClassVar[set[str]] = (
81+
MDState._atom_attributes | {"last_permutation"} # noqa: SLF001
82+
)
8083

8184

8285
nvt_init, nvt_step = nvt_langevin(model=model, dt=0.002, kT=kT)

tests/test_state.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ class ChildState(SimState):
4848
attribute_specified_in_scopes: bool
4949
attribute_not_specified_in_scopes: bool
5050

51-
_atom_attributes = (
52-
*SimState._atom_attributes, # noqa: SLF001
53-
"attribute_specified_in_scopes",
51+
_atom_attributes: typing.ClassVar[set[str]] = (
52+
SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001
5453
)
5554

5655
assert "attribute_not_specified_in_scopes" in str(excinfo.value)
@@ -66,13 +65,11 @@ def test_no_duplicate_attributes_in_scopes() -> None:
6665
class ChildState(SimState):
6766
duplicated_attribute: bool
6867

69-
_system_attributes = (
70-
*SimState._atom_attributes, # noqa: SLF001
71-
"duplicated_attribute",
68+
_system_attributes: typing.ClassVar[set[str]] = (
69+
SimState._atom_attributes | {"duplicated_attribute"} # noqa: SLF001
7270
)
73-
_global_attributes = (
74-
*SimState._global_attributes, # noqa: SLF001
75-
"duplicated_attribute",
71+
_global_attributes: typing.ClassVar[set[str]] = (
72+
SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001
7673
)
7774

7875
assert "are declared multiple times" in str(excinfo.value)
@@ -518,9 +515,9 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None:
518515
class DeformState(SimState, DeformGradMixin):
519516
"""Test class that combines SimState with DeformGradMixin."""
520517

521-
_system_attributes = (
522-
*SimState._system_attributes, # noqa: SLF001
523-
*DeformGradMixin._system_attributes, # noqa: SLF001
518+
_system_attributes: typing.ClassVar[set[str]] = (
519+
SimState._system_attributes # noqa: SLF001
520+
| DeformGradMixin._system_attributes # noqa: SLF001
524521
)
525522

526523
def __init__(

torch_sim/integrators/md.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5+
from typing import ClassVar
56

67
import torch
78

@@ -41,8 +42,12 @@ class MDState(SimState):
4142
energy: torch.Tensor
4243
forces: torch.Tensor
4344

44-
_atom_attributes = (*SimState._atom_attributes, "momenta", "forces") # noqa: SLF001
45-
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
45+
_atom_attributes: ClassVar[set[str]] = (
46+
SimState._atom_attributes | {"momenta", "forces"} # noqa: SLF001
47+
)
48+
_system_attributes: ClassVar[set[str]] = (
49+
SimState._system_attributes | {"energy"} # noqa: SLF001
50+
)
4651

4752
@property
4853
def velocities(self) -> torch.Tensor:

torch_sim/integrators/npt.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any
5+
from typing import Any, ClassVar
66

77
import torch
88

@@ -67,16 +67,17 @@ class NPTLangevinState(SimState):
6767
cell_velocities: torch.Tensor
6868
cell_masses: torch.Tensor
6969

70-
_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001
71-
_system_attributes = (
72-
*SimState._system_attributes, # noqa: SLF001
70+
_atom_attributes: ClassVar[set[str]] = (
71+
SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001
72+
)
73+
_system_attributes: ClassVar[set[str]] = SimState._system_attributes | { # noqa: SLF001
7374
"stress",
7475
"cell_positions",
7576
"cell_velocities",
7677
"cell_masses",
7778
"reference_cell",
7879
"energy",
79-
)
80+
}
8081

8182
@property
8283
def momenta(self) -> torch.Tensor:
@@ -878,19 +879,23 @@ class NPTNoseHooverState(MDState):
878879
barostat: NoseHooverChain
879880
barostat_fns: NoseHooverChainFns
880881

881-
_system_attributes = (
882-
*MDState._system_attributes, # noqa: SLF001
883-
"reference_cell",
884-
"cell_position",
885-
"cell_momentum",
886-
"cell_mass",
882+
_system_attributes: ClassVar[set[str]] = (
883+
MDState._system_attributes # noqa: SLF001
884+
| {
885+
"reference_cell",
886+
"cell_position",
887+
"cell_momentum",
888+
"cell_mass",
889+
}
887890
)
888-
_global_attributes = (
889-
*MDState._global_attributes, # noqa: SLF001
890-
"thermostat",
891-
"barostat",
892-
"thermostat_fns",
893-
"barostat_fns",
891+
_global_attributes: ClassVar[set[str]] = (
892+
MDState._global_attributes # noqa: SLF001
893+
| {
894+
"thermostat",
895+
"barostat",
896+
"thermostat_fns",
897+
"barostat_fns",
898+
}
894899
)
895900

896901
@property

torch_sim/integrators/nvt.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Callable
44
from dataclasses import dataclass
5-
from typing import Any
5+
from typing import Any, ClassVar
66

77
import torch
88

@@ -266,10 +266,8 @@ class NVTNoseHooverState(MDState):
266266
chain: NoseHooverChain
267267
_chain_fns: NoseHooverChainFns
268268

269-
_global_attributes = (
270-
*MDState._global_attributes, # noqa: SLF001
271-
"chain",
272-
"_chain_fns",
269+
_global_attributes: ClassVar[set[str]] = (
270+
MDState._global_attributes | {"chain", "_chain_fns"} # noqa: SLF001
273271
)
274272

275273
@property

torch_sim/monte_carlo.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from collections.abc import Callable
1414
from dataclasses import dataclass
15+
from typing import ClassVar
1516

1617
import torch
1718

@@ -36,8 +37,12 @@ class SwapMCState(SimState):
3637
energy: torch.Tensor
3738
last_permutation: torch.Tensor
3839

39-
_atom_attributes = (*SimState._atom_attributes, "last_permutation") # noqa: SLF001
40-
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
40+
_atom_attributes: ClassVar[set[str]] = (
41+
SimState._atom_attributes | {"last_permutation"} # noqa: SLF001
42+
)
43+
_system_attributes: ClassVar[set[str]] = (
44+
SimState._system_attributes | {"energy"} # noqa: SLF001
45+
)
4146

4247

4348
def generate_swaps(

torch_sim/optimizers.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
from collections.abc import Callable
2222
from dataclasses import dataclass
23-
from typing import Any, Literal, get_args
23+
from typing import Any, ClassVar, Literal, get_args
2424

2525
import torch
2626

@@ -33,27 +33,28 @@
3333
MdFlavor = Literal["vv_fire", "ase_fire"]
3434
vv_fire_key, ase_fire_key = get_args(MdFlavor)
3535

36-
md_atom_attributes = (*SimState._atom_attributes, "forces", "velocities") # noqa: SLF001
36+
md_atom_attributes = SimState._atom_attributes | {"forces", "velocities"} # noqa: SLF001
3737
_fire_system_attributes = (
38-
*SimState._system_attributes, # noqa: SLF001
39-
*DeformGradMixin._system_attributes, # noqa: SLF001
40-
"energy",
41-
"stress",
42-
"cell_positions",
43-
"cell_velocities",
44-
"cell_forces",
45-
"cell_masses",
46-
"cell_factor",
47-
"pressure",
48-
"dt",
49-
"alpha",
50-
"n_pos",
38+
SimState._system_attributes # noqa: SLF001
39+
| DeformGradMixin._system_attributes # noqa: SLF001
40+
| {
41+
"energy",
42+
"stress",
43+
"cell_positions",
44+
"cell_velocities",
45+
"cell_forces",
46+
"cell_masses",
47+
"cell_factor",
48+
"pressure",
49+
"dt",
50+
"alpha",
51+
"n_pos",
52+
}
5153
)
52-
_fire_global_attributes = (
53-
*SimState._global_attributes, # noqa: SLF001
54+
_fire_global_attributes = SimState._global_attributes | { # noqa: SLF001
5455
"hydrostatic_strain",
5556
"constant_volume",
56-
)
57+
}
5758

5859

5960
@dataclass
@@ -78,8 +79,8 @@ class GDState(SimState):
7879
forces: torch.Tensor
7980
energy: torch.Tensor
8081

81-
_atom_attributes = (*SimState._atom_attributes, "forces") # noqa: SLF001
82-
_system_attributes = (*SimState._system_attributes, "energy") # noqa: SLF001
82+
_atom_attributes = SimState._atom_attributes | {"forces"} # noqa: SLF001
83+
_system_attributes = SimState._system_attributes | {"energy"} # noqa: SLF001
8384

8485

8586
def gradient_descent(
@@ -220,20 +221,20 @@ class UnitCellGDState(GDState, DeformGradMixin):
220221
cell_forces: torch.Tensor
221222
cell_masses: torch.Tensor
222223

223-
_system_attributes = (
224-
*GDState._system_attributes, # noqa: SLF001
225-
*DeformGradMixin._system_attributes, # noqa: SLF001
226-
"cell_forces",
227-
"pressure",
228-
"stress",
229-
"cell_positions",
230-
"cell_factor",
231-
"cell_masses",
224+
_system_attributes: ClassVar[set[str]] = (
225+
GDState._system_attributes # noqa: SLF001
226+
| DeformGradMixin._system_attributes # noqa: SLF001
227+
| {
228+
"cell_forces",
229+
"pressure",
230+
"stress",
231+
"cell_positions",
232+
"cell_factor",
233+
"cell_masses",
234+
}
232235
)
233-
_global_attributes = (
234-
*GDState._global_attributes, # noqa: SLF001
235-
"hydrostatic_strain",
236-
"constant_volume",
236+
_global_attributes: ClassVar[set[str]] = (
237+
GDState._global_attributes | {"hydrostatic_strain", "constant_volume"} # noqa: SLF001
237238
)
238239

239240

@@ -523,8 +524,16 @@ class FireState(SimState):
523524
alpha: torch.Tensor
524525
n_pos: torch.Tensor
525526

526-
_atom_attributes = md_atom_attributes
527-
_system_attributes = (*SimState._system_attributes, "energy", "dt", "alpha", "n_pos") # noqa: SLF001
527+
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
528+
_system_attributes: ClassVar[set[str]] = (
529+
SimState._system_attributes # noqa: SLF001
530+
| {
531+
"energy",
532+
"dt",
533+
"alpha",
534+
"n_pos",
535+
}
536+
)
528537

529538

530539
def fire(
@@ -736,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin):
736745
alpha: torch.Tensor
737746
n_pos: torch.Tensor
738747

739-
_atom_attributes = md_atom_attributes
740-
_system_attributes = _fire_system_attributes
741-
_global_attributes = _fire_global_attributes
748+
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
749+
_system_attributes: ClassVar[set[str]] = _fire_system_attributes
750+
_global_attributes: ClassVar[set[str]] = _fire_global_attributes
742751

743752

744753
def unit_cell_fire(
@@ -1028,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
10281037
alpha: torch.Tensor
10291038
n_pos: torch.Tensor
10301039

1031-
_atom_attributes = md_atom_attributes
1032-
_system_attributes = _fire_system_attributes
1033-
_global_attributes = _fire_global_attributes
1040+
_atom_attributes: ClassVar[set[str]] = md_atom_attributes
1041+
_system_attributes: ClassVar[set[str]] = _fire_system_attributes
1042+
_global_attributes: ClassVar[set[str]] = _fire_global_attributes
10341043

10351044

10361045
def frechet_cell_fire(

torch_sim/state.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections import defaultdict
1212
from collections.abc import Generator
1313
from dataclasses import dataclass
14-
from typing import TYPE_CHECKING, Any, Literal, Self, cast
14+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast
1515

1616
import torch
1717

@@ -85,9 +85,14 @@ class SimState:
8585
atomic_numbers: torch.Tensor
8686
system_idx: torch.Tensor
8787

88-
_atom_attributes = ("positions", "masses", "atomic_numbers", "system_idx")
89-
_system_attributes = ("cell",)
90-
_global_attributes = ("pbc",)
88+
_atom_attributes: ClassVar[set[str]] = {
89+
"positions",
90+
"masses",
91+
"atomic_numbers",
92+
"system_idx",
93+
}
94+
_system_attributes: ClassVar[set[str]] = {"cell"}
95+
_global_attributes: ClassVar[set[str]] = {"pbc"}
9196

9297
def __init__(
9398
self,
@@ -444,12 +449,14 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None:
444449
@classmethod
445450
def _assert_all_attributes_have_defined_scope(cls) -> None:
446451
all_defined_attributes = (
447-
cls._atom_attributes + cls._system_attributes + cls._global_attributes
452+
cls._atom_attributes | cls._system_attributes | cls._global_attributes
448453
)
449454
# 1) assert that no attribute is defined twice in all_defined_attributes
450-
duplicates = [
451-
x for x in all_defined_attributes if all_defined_attributes.count(x) > 1
452-
]
455+
duplicates = (
456+
(cls._atom_attributes & cls._system_attributes)
457+
| (cls._atom_attributes & cls._global_attributes)
458+
| (cls._system_attributes & cls._global_attributes)
459+
)
453460
if duplicates:
454461
raise TypeError(
455462
f"Attributes {duplicates} are declared multiple times in {cls.__name__} "
@@ -492,7 +499,7 @@ class DeformGradMixin:
492499

493500
reference_cell: torch.Tensor
494501

495-
_system_attributes = ("reference_cell",)
502+
_system_attributes: ClassVar[set[str]] = {"reference_cell"}
496503

497504
@property
498505
def reference_row_vector_cell(self) -> torch.Tensor:

0 commit comments

Comments
 (0)