Skip to content

Commit 8773f45

Browse files
committed
rename per_batch to per_graph
1 parent 7d27048 commit 8773f45

File tree

11 files changed

+72
-72
lines changed

11 files changed

+72
-72
lines changed

examples/scripts/1_Introduction/1.2_MACE.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@
6464
atomic_numbers = torch.tensor(atomic_numbers_numpy, device=device, dtype=torch.int)
6565

6666
# create batch index array of shape (16,) which is 0 for first 8 atoms, 1 for last 8 atoms
67-
atoms_per_batch = torch.tensor(
67+
atoms_per_graph = torch.tensor(
6868
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
6969
)
7070
batch = torch.repeat_interleave(
71-
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
71+
torch.arange(len(atoms_per_graph), device=device), atoms_per_graph
7272
)
7373

7474
# You can see their shapes are as expected

examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@
9494
masses = torch.tensor(masses_numpy, device=device, dtype=dtype)
9595
9696
# Create batch indices tensor for scatter operations
97-
atoms_per_batch = torch.tensor(
97+
atoms_per_graph = torch.tensor(
9898
[len(atoms) for atoms in atoms_list], device=device, dtype=torch.int
9999
)
100100
batch_indices = torch.repeat_interleave(
101-
torch.arange(len(atoms_per_batch), device=device), atoms_per_batch
101+
torch.arange(len(atoms_per_graph), device=device), atoms_per_graph
102102
)
103103
"""
104104

examples/tutorials/state_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@
260260
scope = infer_property_scope(md_state)
261261
print("Global properties:", scope["global"])
262262
print("Per-atom properties:", scope["per_atom"])
263-
print("Per-batch properties:", scope["per_batch"])
263+
print("Per-batch properties:", scope["per_graph"])
264264

265265

266266
# %% [markdown]

tests/test_integrators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def test_compare_single_vs_batched_integrators(
372372
torch.testing.assert_close(single_state.energy, final_state.energy)
373373

374374

375-
def test_compute_cell_force_atoms_per_batch():
375+
def test_compute_cell_force_atoms_per_graph():
376376
"""Test that compute_cell_force correctly scales by number of atoms per batch.
377377
378378
Covers fix in https://github.com/Radical-AI/torch-sim/pull/153."""

tests/test_state.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_infer_sim_state_property_scope(si_sim_state: ts.SimState) -> None:
2929
scope = infer_property_scope(si_sim_state)
3030
assert set(scope["global"]) == {"pbc"}
3131
assert set(scope["per_atom"]) == {"positions", "masses", "atomic_numbers", "batch"}
32-
assert set(scope["per_batch"]) == {"cell"}
32+
assert set(scope["per_graph"]) == {"cell"}
3333

3434

3535
def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
@@ -50,7 +50,7 @@ def test_infer_md_state_property_scope(si_sim_state: ts.SimState) -> None:
5050
"forces",
5151
"momenta",
5252
}
53-
assert set(scope["per_batch"]) == {"cell", "energy"}
53+
assert set(scope["per_graph"]) == {"cell", "energy"}
5454

5555

5656
def test_slice_substate(
@@ -156,9 +156,9 @@ def test_concatenate_si_and_fe_states(
156156
)
157157
assert torch.all(concatenated.batch == expected_batch)
158158

159-
# check n_atoms_per_batch
159+
# check n_atoms_per_graph
160160
assert torch.all(
161-
concatenated.n_atoms_per_batch
161+
concatenated.n_atoms_per_graph
162162
== torch.tensor(
163163
[si_sim_state.n_atoms, fe_supercell_sim_state.n_atoms],
164164
device=concatenated.device,

tests/test_trajectory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,9 @@ def test_multi_batch_reporter(
678678

679679
# Check that each trajectory has the correct number of atoms
680680
# (should be half of the total in the double state)
681-
atoms_per_batch = si_double_sim_state.positions.shape[0] // 2
682-
assert traj0.get_array("positions").shape[1] == atoms_per_batch
683-
assert traj1.get_array("positions").shape[1] == atoms_per_batch
681+
atoms_per_graph = si_double_sim_state.positions.shape[0] // 2
682+
assert traj0.get_array("positions").shape[1] == atoms_per_graph
683+
assert traj1.get_array("positions").shape[1] == atoms_per_graph
684684

685685
# Check property data
686686
assert "ones" in traj0.array_registry

torch_sim/integrators/npt.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def _compute_cell_force(
135135
3, device=state.device, dtype=state.dtype
136136
).unsqueeze(0)
137137

138-
# Correct implementation with scaling by n_atoms_per_batch
139-
return virial + e_kin_per_atom * state.n_atoms_per_batch.view(-1, 1, 1)
138+
# Correct implementation with scaling by n_atoms_per_graph
139+
return virial + e_kin_per_atom * state.n_atoms_per_graph.view(-1, 1, 1)
140140

141141

142142
def npt_langevin( # noqa: C901, PLR0915
@@ -663,13 +663,13 @@ def npt_init(
663663

664664
# Calculate cell masses based on system size and temperature
665665
# This follows standard NPT barostat mass scaling
666-
n_atoms_per_batch = torch.bincount(state.batch)
666+
n_atoms_per_graph = torch.bincount(state.batch)
667667
batch_kT = (
668668
kT.expand(state.n_batches)
669669
if isinstance(kT, torch.Tensor) and kT.ndim == 0
670670
else kT
671671
)
672-
cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau
672+
cell_masses = (n_atoms_per_graph + 1) * batch_kT * b_tau * b_tau
673673

674674
# Create the initial state
675675
return NPTLangevinState(
@@ -735,8 +735,8 @@ def npt_update(
735735

736736
# Update barostat mass based on current temperature
737737
# This ensures proper coupling between system and barostat
738-
n_atoms_per_batch = torch.bincount(state.batch)
739-
state.cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau
738+
n_atoms_per_graph = torch.bincount(state.batch)
739+
state.cell_masses = (n_atoms_per_graph + 1) * batch_kT * b_tau * b_tau
740740

741741
# Compute model output for current state
742742
model_output = model(state)
@@ -1017,8 +1017,8 @@ def update_cell_mass(
10171017
kT_batch = kT.expand(state.n_batches) if kT.ndim == 0 else kT
10181018

10191019
# Calculate cell masses for each batch
1020-
n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches)
1021-
cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * state.barostat.tau**2
1020+
n_atoms_per_graph = torch.bincount(state.batch, minlength=state.n_batches)
1021+
cell_mass = dim * (n_atoms_per_graph + 1) * kT_batch * state.barostat.tau**2
10221022

10231023
# Update state with new cell masses
10241024
state.cell_mass = cell_mass.to(device=device, dtype=dtype)
@@ -1213,15 +1213,15 @@ def compute_cell_force(
12131213

12141214
# Compute kinetic energy contribution per batch
12151215
# Split momenta and masses by batch
1216-
KE_per_batch = torch.zeros(
1216+
KE_per_graph = torch.zeros(
12171217
n_batches, device=positions.device, dtype=positions.dtype
12181218
)
12191219
for b in range(n_batches):
12201220
batch_mask = batch == b
12211221
if batch_mask.any():
12221222
batch_momenta = momenta[batch_mask]
12231223
batch_masses = masses[batch_mask]
1224-
KE_per_batch[b] = calc_kinetic_energy(batch_momenta, batch_masses)
1224+
KE_per_graph[b] = calc_kinetic_energy(batch_momenta, batch_masses)
12251225

12261226
# Get stress tensor and compute trace per batch
12271227
# Handle stress tensor with batch dimension
@@ -1236,7 +1236,7 @@ def compute_cell_force(
12361236
# Compute force on cell coordinate per batch
12371237
# F = alpha * KE - dU/dV - P*V*d
12381238
return (
1239-
(alpha * KE_per_batch)
1239+
(alpha * KE_per_graph)
12401240
- (internal_pressure * volume)
12411241
- (external_pressure * volume * dim)
12421242
)
@@ -1285,8 +1285,8 @@ def npt_inner_step(
12851285
model_output = model(state)
12861286

12871287
# First half step: Update momenta
1288-
n_atoms_per_batch = torch.bincount(state.batch, minlength=state.n_batches)
1289-
alpha = 1 + 1 / n_atoms_per_batch # [n_batches]
1288+
n_atoms_per_graph = torch.bincount(state.batch, minlength=state.n_batches)
1289+
alpha = 1 + 1 / n_atoms_per_graph # [n_batches]
12901290

12911291
cell_force_val = compute_cell_force(
12921292
alpha=alpha,
@@ -1425,8 +1425,8 @@ def npt_nose_hoover_init(
14251425
kT_batch = kT.expand(n_batches) if kT.ndim == 0 else kT
14261426

14271427
# Calculate cell masses for each batch
1428-
n_atoms_per_batch = torch.bincount(state.batch, minlength=n_batches)
1429-
cell_mass = dim * (n_atoms_per_batch + 1) * kT_batch * b_tau**2
1428+
n_atoms_per_graph = torch.bincount(state.batch, minlength=n_batches)
1429+
cell_mass = dim * (n_atoms_per_graph + 1) * kT_batch * b_tau**2
14301430
cell_mass = cell_mass.to(device=device, dtype=dtype)
14311431

14321432
# Calculate cell kinetic energy (using first batch for initialization)
@@ -1596,19 +1596,19 @@ def npt_nose_hoover_invariant(
15961596
e_pot = state.energy # Should be scalar or [n_batches]
15971597

15981598
# Calculate kinetic energy of particles per batch
1599-
e_kin_per_batch = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch)
1599+
e_kin_per_graph = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch)
16001600

16011601
# Calculate degrees of freedom per batch
1602-
n_atoms_per_batch = torch.bincount(state.batch)
1603-
DOF_per_batch = (
1604-
n_atoms_per_batch * state.positions.shape[-1]
1602+
n_atoms_per_graph = torch.bincount(state.batch)
1603+
DOF_per_graph = (
1604+
n_atoms_per_graph * state.positions.shape[-1]
16051605
) # n_atoms * n_dimensions
16061606

16071607
# Initialize total energy with PE + KE
16081608
if isinstance(e_pot, torch.Tensor) and e_pot.ndim > 0:
1609-
e_tot = e_pot + e_kin_per_batch # [n_batches]
1609+
e_tot = e_pot + e_kin_per_graph # [n_batches]
16101610
else:
1611-
e_tot = e_pot + e_kin_per_batch # [n_batches]
1611+
e_tot = e_pot + e_kin_per_graph # [n_batches]
16121612

16131613
# Add thermostat chain contributions
16141614
# Note: These are global thermostat variables, so we add them to each batch
@@ -1618,14 +1618,14 @@ def npt_nose_hoover_invariant(
16181618
2 * state.thermostat.masses[0]
16191619
)
16201620

1621-
# Ensure kT can broadcast properly with DOF_per_batch
1621+
# Ensure kT can broadcast properly with DOF_per_graph
16221622
if isinstance(kT, torch.Tensor) and kT.ndim == 0:
1623-
# Scalar kT - expand to match DOF_per_batch shape
1624-
kT_expanded = kT.expand_as(DOF_per_batch)
1623+
# Scalar kT - expand to match DOF_per_graph shape
1624+
kT_expanded = kT.expand_as(DOF_per_graph)
16251625
else:
16261626
kT_expanded = kT
16271627

1628-
thermostat_energy += DOF_per_batch * kT_expanded * state.thermostat.positions[0]
1628+
thermostat_energy += DOF_per_graph * kT_expanded * state.thermostat.positions[0]
16291629

16301630
# Add remaining thermostat terms
16311631
for pos, momentum, mass in zip(

torch_sim/integrators/nvt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,14 +370,14 @@ def nvt_nose_hoover_init(
370370
KE = calc_kinetic_energy(momenta, state.masses, batch=state.batch)
371371

372372
# Calculate degrees of freedom per batch
373-
n_atoms_per_batch = torch.bincount(state.batch)
374-
dof_per_batch = (
375-
n_atoms_per_batch * state.positions.shape[-1]
373+
n_atoms_per_graph = torch.bincount(state.batch)
374+
dof_per_graph = (
375+
n_atoms_per_graph * state.positions.shape[-1]
376376
) # n_atoms * n_dimensions
377377

378378
# For now, sum the per-batch DOF as chain expects a single int
379379
# This is a limitation that should be addressed in the chain implementation
380-
total_dof = int(dof_per_batch.sum().item())
380+
total_dof = int(dof_per_graph.sum().item())
381381

382382
# Initialize state
383383
state = NVTNoseHooverState(
@@ -479,8 +479,8 @@ def nvt_nose_hoover_invariant(
479479
e_kin = calc_kinetic_energy(state.momenta, state.masses, batch=state.batch)
480480

481481
# Get system degrees of freedom per batch
482-
n_atoms_per_batch = torch.bincount(state.batch)
483-
dof = n_atoms_per_batch * state.positions.shape[-1] # n_atoms * n_dimensions
482+
n_atoms_per_graph = torch.bincount(state.batch)
483+
dof = n_atoms_per_graph * state.positions.shape[-1] # n_atoms * n_dimensions
484484

485485
# Start with system energy
486486
e_tot = e_pot + e_kin

torch_sim/io.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ def atoms_to_state(
226226
)
227227

228228
# Create batch indices using repeat_interleave
229-
atoms_per_batch = torch.tensor([len(a) for a in atoms_list], device=device)
229+
atoms_per_graph = torch.tensor([len(a) for a in atoms_list], device=device)
230230
batch = torch.repeat_interleave(
231-
torch.arange(len(atoms_list), device=device), atoms_per_batch
231+
torch.arange(len(atoms_list), device=device), atoms_per_graph
232232
)
233233

234234
# Verify consistent pbc
@@ -298,9 +298,9 @@ def structures_to_state(
298298
)
299299

300300
# Create batch indices
301-
atoms_per_batch = torch.tensor([len(s) for s in struct_list], device=device)
301+
atoms_per_graph = torch.tensor([len(s) for s in struct_list], device=device)
302302
batch = torch.repeat_interleave(
303-
torch.arange(len(struct_list), device=device), atoms_per_batch
303+
torch.arange(len(struct_list), device=device), atoms_per_graph
304304
)
305305

306306
return ts.SimState(
@@ -369,9 +369,9 @@ def phonopy_to_state(
369369
)
370370

371371
# Create batch indices using repeat_interleave
372-
atoms_per_batch = torch.tensor([len(a) for a in phonopy_atoms_list], device=device)
372+
atoms_per_graph = torch.tensor([len(a) for a in phonopy_atoms_list], device=device)
373373
batch = torch.repeat_interleave(
374-
torch.arange(len(phonopy_atoms_list), device=device), atoms_per_batch
374+
torch.arange(len(phonopy_atoms_list), device=device), atoms_per_graph
375375
)
376376

377377
"""

torch_sim/quantities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ def calc_kT( # noqa: N802
6060

6161
# Count degrees of freedom per batch
6262
batch_sizes = torch.bincount(batch)
63-
dof_per_batch = batch_sizes * squared_term.shape[-1] # multiply by n_dimensions
63+
dof_per_graph = batch_sizes * squared_term.shape[-1] # multiply by n_dimensions
6464

6565
# Calculate temperature per batch
6666
batch_sums = torch.segment_reduce(
6767
flattened_squared, reduce="sum", lengths=batch_sizes
6868
)
69-
return batch_sums / dof_per_batch
69+
return batch_sums / dof_per_graph
7070

7171

7272
def calc_temperature(

0 commit comments

Comments
 (0)