20
20
import functools
21
21
from collections .abc import Callable
22
22
from dataclasses import dataclass
23
- from typing import Any , Literal , get_args
23
+ from typing import Any , ClassVar , Literal , get_args
24
24
25
25
import torch
26
26
33
33
MdFlavor = Literal ["vv_fire" , "ase_fire" ]
34
34
vv_fire_key , ase_fire_key = get_args (MdFlavor )
35
35
36
- md_atom_attributes = ( * SimState ._atom_attributes , "forces" , "velocities" ) # noqa: SLF001
36
+ md_atom_attributes = SimState ._atom_attributes | { "forces" , "velocities" } # noqa: SLF001
37
37
_fire_system_attributes = (
38
- * SimState ._system_attributes , # noqa: SLF001
39
- * DeformGradMixin ._system_attributes , # noqa: SLF001
40
- "energy" ,
41
- "stress" ,
42
- "cell_positions" ,
43
- "cell_velocities" ,
44
- "cell_forces" ,
45
- "cell_masses" ,
46
- "cell_factor" ,
47
- "pressure" ,
48
- "dt" ,
49
- "alpha" ,
50
- "n_pos" ,
38
+ SimState ._system_attributes # noqa: SLF001
39
+ | DeformGradMixin ._system_attributes # noqa: SLF001
40
+ | {
41
+ "energy" ,
42
+ "stress" ,
43
+ "cell_positions" ,
44
+ "cell_velocities" ,
45
+ "cell_forces" ,
46
+ "cell_masses" ,
47
+ "cell_factor" ,
48
+ "pressure" ,
49
+ "dt" ,
50
+ "alpha" ,
51
+ "n_pos" ,
52
+ }
51
53
)
52
- _fire_global_attributes = (
53
- * SimState ._global_attributes , # noqa: SLF001
54
+ _fire_global_attributes = SimState ._global_attributes | { # noqa: SLF001
54
55
"hydrostatic_strain" ,
55
56
"constant_volume" ,
56
- )
57
+ }
57
58
58
59
59
60
@dataclass
@@ -78,8 +79,8 @@ class GDState(SimState):
78
79
forces : torch .Tensor
79
80
energy : torch .Tensor
80
81
81
- _atom_attributes = ( * SimState ._atom_attributes , "forces" ) # noqa: SLF001
82
- _system_attributes = ( * SimState ._system_attributes , "energy" ) # noqa: SLF001
82
+ _atom_attributes = SimState ._atom_attributes | { "forces" } # noqa: SLF001
83
+ _system_attributes = SimState ._system_attributes | { "energy" } # noqa: SLF001
83
84
84
85
85
86
def gradient_descent (
@@ -220,20 +221,20 @@ class UnitCellGDState(GDState, DeformGradMixin):
220
221
cell_forces : torch .Tensor
221
222
cell_masses : torch .Tensor
222
223
223
- _system_attributes = (
224
- * GDState ._system_attributes , # noqa: SLF001
225
- * DeformGradMixin ._system_attributes , # noqa: SLF001
226
- "cell_forces" ,
227
- "pressure" ,
228
- "stress" ,
229
- "cell_positions" ,
230
- "cell_factor" ,
231
- "cell_masses" ,
224
+ _system_attributes : ClassVar [set [str ]] = (
225
+ GDState ._system_attributes # noqa: SLF001
226
+ | DeformGradMixin ._system_attributes # noqa: SLF001
227
+ | {
228
+ "cell_forces" ,
229
+ "pressure" ,
230
+ "stress" ,
231
+ "cell_positions" ,
232
+ "cell_factor" ,
233
+ "cell_masses" ,
234
+ }
232
235
)
233
- _global_attributes = (
234
- * GDState ._global_attributes , # noqa: SLF001
235
- "hydrostatic_strain" ,
236
- "constant_volume" ,
236
+ _global_attributes : ClassVar [set [str ]] = (
237
+ GDState ._global_attributes | {"hydrostatic_strain" , "constant_volume" } # noqa: SLF001
237
238
)
238
239
239
240
@@ -523,8 +524,16 @@ class FireState(SimState):
523
524
alpha : torch .Tensor
524
525
n_pos : torch .Tensor
525
526
526
- _atom_attributes = md_atom_attributes
527
- _system_attributes = (* SimState ._system_attributes , "energy" , "dt" , "alpha" , "n_pos" ) # noqa: SLF001
527
+ _atom_attributes : ClassVar [set [str ]] = md_atom_attributes
528
+ _system_attributes : ClassVar [set [str ]] = (
529
+ SimState ._system_attributes # noqa: SLF001
530
+ | {
531
+ "energy" ,
532
+ "dt" ,
533
+ "alpha" ,
534
+ "n_pos" ,
535
+ }
536
+ )
528
537
529
538
530
539
def fire (
@@ -736,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin):
736
745
alpha : torch .Tensor
737
746
n_pos : torch .Tensor
738
747
739
- _atom_attributes = md_atom_attributes
740
- _system_attributes = _fire_system_attributes
741
- _global_attributes = _fire_global_attributes
748
+ _atom_attributes : ClassVar [ set [ str ]] = md_atom_attributes
749
+ _system_attributes : ClassVar [ set [ str ]] = _fire_system_attributes
750
+ _global_attributes : ClassVar [ set [ str ]] = _fire_global_attributes
742
751
743
752
744
753
def unit_cell_fire (
@@ -1028,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
1028
1037
alpha : torch .Tensor
1029
1038
n_pos : torch .Tensor
1030
1039
1031
- _atom_attributes = md_atom_attributes
1032
- _system_attributes = _fire_system_attributes
1033
- _global_attributes = _fire_global_attributes
1040
+ _atom_attributes : ClassVar [ set [ str ]] = md_atom_attributes
1041
+ _system_attributes : ClassVar [ set [ str ]] = _fire_system_attributes
1042
+ _global_attributes : ClassVar [ set [ str ]] = _fire_global_attributes
1034
1043
1035
1044
1036
1045
def frechet_cell_fire (
0 commit comments