|
13 | 13 | _pop_states,
|
14 | 14 | _slice_state,
|
15 | 15 | concatenate_states,
|
16 |
| - infer_property_scope, |
| 16 | + get_attrs_for_scope, |
17 | 17 | initialize_state,
|
18 | 18 | )
|
19 | 19 |
|
|
24 | 24 | from pymatgen.core import Structure
|
25 | 25 |
|
26 | 26 |
|
27 |
| -def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None: |
28 |
| - """Test inference of property scope.""" |
29 |
| - scope = infer_property_scope(si_sim_state) |
30 |
| - assert set(scope["global"]) == {"pbc"} |
31 |
| - assert set(scope["per_atom"]) == { |
| 27 | +def test_get_attrs_for_scope(si_sim_state: ts.SimState) -> None: |
| 28 | + """Test getting attributes for a scope.""" |
| 29 | + per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) |
| 30 | + assert set(per_atom_attrs.keys()) == { |
32 | 31 | "positions",
|
33 | 32 | "masses",
|
34 | 33 | "atomic_numbers",
|
35 | 34 | "system_idx",
|
36 | 35 | }
|
37 |
| - assert set(scope["per_system"]) == {"cell"} |
| 36 | + per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) |
| 37 | + assert set(per_system_attrs.keys()) == {"cell"} |
| 38 | + global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) |
| 39 | + assert set(global_attrs.keys()) == {"pbc"} |
38 | 40 |
|
39 | 41 |
|
40 |
| -def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None: |
41 |
| - """Test inference of property scope.""" |
42 |
| - state = MDState( |
43 |
| - **asdict(si_sim_state), |
44 |
| - momenta=torch.randn_like(si_sim_state.positions), |
45 |
| - forces=torch.randn_like(si_sim_state.positions), |
46 |
| - energy=torch.zeros((1,)), |
47 |
| - ) |
48 |
| - scope = infer_property_scope(state) |
49 |
| - assert set(scope["global"]) == {"pbc"} |
50 |
| - assert set(scope["per_atom"]) == { |
51 |
| - "positions", |
52 |
| - "masses", |
53 |
| - "atomic_numbers", |
54 |
| - "system_idx", |
55 |
| - "forces", |
56 |
| - "momenta", |
57 |
| - } |
58 |
| - assert set(scope["per_system"]) == {"cell", "energy"} |
| 42 | +def test_all_attributes_must_be_specified_in_scopes() -> None: |
| 43 | + """Test that an error is raised when we forget to specify the scope |
| 44 | + for an attribute in a child SimState class.""" |
| 45 | + with pytest.raises(TypeError) as excinfo: |
| 46 | + |
| 47 | + class ChildState(SimState): |
| 48 | + attribute_specified_in_scopes: bool |
| 49 | + attribute_not_specified_in_scopes: bool |
| 50 | + |
| 51 | + _atom_attributes = ( |
| 52 | + SimState._atom_attributes | {"attribute_specified_in_scopes"} # noqa: SLF001 |
| 53 | + ) |
| 54 | + |
| 55 | + assert "attribute_not_specified_in_scopes" in str(excinfo.value) |
| 56 | + assert "attribute_specified_in_scopes" not in str(excinfo.value) |
| 57 | + |
| 58 | + |
| 59 | +def test_no_duplicate_attributes_in_scopes() -> None: |
| 60 | + """Test that no attributes are specified in multiple scopes.""" |
| 61 | + |
| 62 | + # Capture the exception information using "as excinfo" |
| 63 | + with pytest.raises(TypeError) as excinfo: |
| 64 | + |
| 65 | + class ChildState(SimState): |
| 66 | + duplicated_attribute: bool |
| 67 | + |
| 68 | + _system_attributes = SimState._system_attributes | {"duplicated_attribute"} # noqa: SLF001 |
| 69 | + _global_attributes = SimState._global_attributes | {"duplicated_attribute"} # noqa: SLF001 |
| 70 | + |
| 71 | + assert "are declared multiple times" in str(excinfo.value) |
| 72 | + assert "duplicated_attribute" in str(excinfo.value) |
59 | 73 |
|
60 | 74 |
|
61 | 75 | def test_slice_substate(
|
@@ -497,6 +511,11 @@ def test_column_vector_cell(si_sim_state: ts.SimState) -> None:
|
497 | 511 | class DeformState(SimState, DeformGradMixin):
|
498 | 512 | """Test class that combines SimState with DeformGradMixin."""
|
499 | 513 |
|
| 514 | + _system_attributes = ( |
| 515 | + SimState._system_attributes # noqa: SLF001 |
| 516 | + | DeformGradMixin._system_attributes # noqa: SLF001 |
| 517 | + ) |
| 518 | + |
500 | 519 | def __init__(
|
501 | 520 | self,
|
502 | 521 | *args,
|
|
0 commit comments