20
20
import functools
21
21
from collections .abc import Callable
22
22
from dataclasses import dataclass
23
- from typing import Any , ClassVar , Literal , get_args
23
+ from typing import Any , Literal , get_args
24
24
25
25
import torch
26
26
@@ -221,7 +221,7 @@ class UnitCellGDState(GDState, DeformGradMixin):
221
221
cell_forces : torch .Tensor
222
222
cell_masses : torch .Tensor
223
223
224
- _system_attributes : ClassVar [ set [ str ]] = (
224
+ _system_attributes = (
225
225
GDState ._system_attributes # noqa: SLF001
226
226
| DeformGradMixin ._system_attributes # noqa: SLF001
227
227
| {
@@ -233,7 +233,7 @@ class UnitCellGDState(GDState, DeformGradMixin):
233
233
"cell_masses" ,
234
234
}
235
235
)
236
- _global_attributes : ClassVar [ set [ str ]] = (
236
+ _global_attributes = (
237
237
GDState ._global_attributes | {"hydrostatic_strain" , "constant_volume" } # noqa: SLF001
238
238
)
239
239
@@ -524,8 +524,8 @@ class FireState(SimState):
524
524
alpha : torch .Tensor
525
525
n_pos : torch .Tensor
526
526
527
- _atom_attributes : ClassVar [ set [ str ]] = md_atom_attributes
528
- _system_attributes : ClassVar [ set [ str ]] = (
527
+ _atom_attributes = md_atom_attributes
528
+ _system_attributes = (
529
529
SimState ._system_attributes # noqa: SLF001
530
530
| {
531
531
"energy" ,
@@ -745,9 +745,9 @@ class UnitCellFireState(SimState, DeformGradMixin):
745
745
alpha : torch .Tensor
746
746
n_pos : torch .Tensor
747
747
748
- _atom_attributes : ClassVar [ set [ str ]] = md_atom_attributes
749
- _system_attributes : ClassVar [ set [ str ]] = _fire_system_attributes
750
- _global_attributes : ClassVar [ set [ str ]] = _fire_global_attributes
748
+ _atom_attributes = md_atom_attributes
749
+ _system_attributes = _fire_system_attributes
750
+ _global_attributes = _fire_global_attributes
751
751
752
752
753
753
def unit_cell_fire (
@@ -1037,9 +1037,9 @@ class FrechetCellFIREState(SimState, DeformGradMixin):
1037
1037
alpha : torch .Tensor
1038
1038
n_pos : torch .Tensor
1039
1039
1040
- _atom_attributes : ClassVar [ set [ str ]] = md_atom_attributes
1041
- _system_attributes : ClassVar [ set [ str ]] = _fire_system_attributes
1042
- _global_attributes : ClassVar [ set [ str ]] = _fire_global_attributes
1040
+ _atom_attributes = md_atom_attributes
1041
+ _system_attributes = _fire_system_attributes
1042
+ _global_attributes = _fire_global_attributes
1043
1043
1044
1044
1045
1045
def frechet_cell_fire (
0 commit comments