Skip to content

Commit 1710363

Browse files
committed
cleanup
1 parent 79c6c75 commit 1710363

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/test_state.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_concatenate_two_si_states(
109109
assert concatenated.graph_idx.shape == si_double_sim_state.graph_idx.shape
110110

111111
# Check graph indices
112-
expected_graphs = torch.cat(
112+
expected_graph_indices = torch.cat(
113113
[
114114
torch.zeros(
115115
si_sim_state.n_atoms, dtype=torch.int64, device=si_sim_state.device
@@ -119,7 +119,7 @@ def test_concatenate_two_si_states(
119119
),
120120
]
121121
)
122-
assert torch.all(concatenated.graph_idx == expected_graphs)
122+
assert torch.all(concatenated.graph_idx == expected_graph_indices)
123123

124124
# Check that positions match (accounting for graph indices)
125125
for graph_idx in range(2):
@@ -153,13 +153,13 @@ def test_concatenate_si_and_fe_states(
153153
# Check graph indices
154154
si_atoms = si_sim_state.n_atoms
155155
fe_atoms = fe_supercell_sim_state.n_atoms
156-
expected_graphs = torch.cat(
156+
expected_graph_indices = torch.cat(
157157
[
158158
torch.zeros(si_atoms, dtype=torch.int64, device=si_sim_state.device),
159159
torch.ones(fe_atoms, dtype=torch.int64, device=fe_supercell_sim_state.device),
160160
]
161161
)
162-
assert torch.all(concatenated.graph_idx == expected_graphs)
162+
assert torch.all(concatenated.graph_idx == expected_graph_indices)
163163

164164
# check n_atoms_per_graph
165165
assert torch.all(
@@ -203,15 +203,15 @@ def test_concatenate_double_si_and_fe_states(
203203
fe_atoms = fe_supercell_sim_state.n_atoms
204204

205205
# The double Si state already has graphs 0 and 1, so Ar should be graph 2
206-
expected_graphs = torch.cat(
206+
expected_graph_indices = torch.cat(
207207
[
208208
si_double_sim_state.graph_idx,
209209
torch.full(
210210
(fe_atoms,), 2, dtype=torch.int64, device=fe_supercell_sim_state.device
211211
),
212212
]
213213
)
214-
assert torch.all(concatenated.graph_idx == expected_graphs)
214+
assert torch.all(concatenated.graph_idx == expected_graph_indices)
215215
assert torch.unique(concatenated.graph_idx).shape[0] == 3
216216

217217
# Check that we can slice back to the original states
@@ -638,6 +638,7 @@ def test_deprecated_batch_properties_equal_to_new_graph_properties(
638638
assert state.n_batches == state.n_graphs
639639
assert torch.allclose(state.n_atoms_per_batch, state.n_atoms_per_graph)
640640

641+
# now test that assigning the old .batch property behaves the same
641642
new_graph_idx = torch.arange(4, device=device)
642643
state.batch = new_graph_idx
643644
assert torch.allclose(state.graph_idx, new_graph_idx)

0 commit comments

Comments
 (0)