@@ -109,7 +109,7 @@ def test_concatenate_two_si_states(
109
109
assert concatenated .graph_idx .shape == si_double_sim_state .graph_idx .shape
110
110
111
111
# Check graph indices
112
- expected_graphs = torch .cat (
112
+ expected_graph_indices = torch .cat (
113
113
[
114
114
torch .zeros (
115
115
si_sim_state .n_atoms , dtype = torch .int64 , device = si_sim_state .device
@@ -119,7 +119,7 @@ def test_concatenate_two_si_states(
119
119
),
120
120
]
121
121
)
122
- assert torch .all (concatenated .graph_idx == expected_graphs )
122
+ assert torch .all (concatenated .graph_idx == expected_graph_indices )
123
123
124
124
# Check that positions match (accounting for graph indices)
125
125
for graph_idx in range (2 ):
@@ -153,13 +153,13 @@ def test_concatenate_si_and_fe_states(
153
153
# Check graph indices
154
154
si_atoms = si_sim_state .n_atoms
155
155
fe_atoms = fe_supercell_sim_state .n_atoms
156
- expected_graphs = torch .cat (
156
+ expected_graph_indices = torch .cat (
157
157
[
158
158
torch .zeros (si_atoms , dtype = torch .int64 , device = si_sim_state .device ),
159
159
torch .ones (fe_atoms , dtype = torch .int64 , device = fe_supercell_sim_state .device ),
160
160
]
161
161
)
162
- assert torch .all (concatenated .graph_idx == expected_graphs )
162
+ assert torch .all (concatenated .graph_idx == expected_graph_indices )
163
163
164
164
# check n_atoms_per_graph
165
165
assert torch .all (
@@ -203,15 +203,15 @@ def test_concatenate_double_si_and_fe_states(
203
203
fe_atoms = fe_supercell_sim_state .n_atoms
204
204
205
205
# The double Si state already has graphs 0 and 1, so Ar should be graph 2
206
- expected_graphs = torch .cat (
206
+ expected_graph_indices = torch .cat (
207
207
[
208
208
si_double_sim_state .graph_idx ,
209
209
torch .full (
210
210
(fe_atoms ,), 2 , dtype = torch .int64 , device = fe_supercell_sim_state .device
211
211
),
212
212
]
213
213
)
214
- assert torch .all (concatenated .graph_idx == expected_graphs )
214
+ assert torch .all (concatenated .graph_idx == expected_graph_indices )
215
215
assert torch .unique (concatenated .graph_idx ).shape [0 ] == 3
216
216
217
217
# Check that we can slice back to the original states
@@ -638,6 +638,7 @@ def test_deprecated_batch_properties_equal_to_new_graph_properties(
638
638
assert state .n_batches == state .n_graphs
639
639
assert torch .allclose (state .n_atoms_per_batch , state .n_atoms_per_graph )
640
640
641
+ # now test that assigning the old .batch property behaves the same
641
642
new_graph_idx = torch .arange (4 , device = device )
642
643
state .batch = new_graph_idx
643
644
assert torch .allclose (state .graph_idx , new_graph_idx )
0 commit comments