@@ -135,8 +135,8 @@ def _compute_cell_force(
135
135
3 , device = state .device , dtype = state .dtype
136
136
).unsqueeze (0 )
137
137
138
- # Correct implementation with scaling by n_atoms_per_batch
139
- return virial + e_kin_per_atom * state .n_atoms_per_batch .view (- 1 , 1 , 1 )
138
+ # Correct implementation with scaling by n_atoms_per_graph
139
+ return virial + e_kin_per_atom * state .n_atoms_per_graph .view (- 1 , 1 , 1 )
140
140
141
141
142
142
def npt_langevin ( # noqa: C901, PLR0915
@@ -663,13 +663,13 @@ def npt_init(
663
663
664
664
# Calculate cell masses based on system size and temperature
665
665
# This follows standard NPT barostat mass scaling
666
- n_atoms_per_batch = torch .bincount (state .batch )
666
+ n_atoms_per_graph = torch .bincount (state .batch )
667
667
batch_kT = (
668
668
kT .expand (state .n_batches )
669
669
if isinstance (kT , torch .Tensor ) and kT .ndim == 0
670
670
else kT
671
671
)
672
- cell_masses = (n_atoms_per_batch + 1 ) * batch_kT * b_tau * b_tau
672
+ cell_masses = (n_atoms_per_graph + 1 ) * batch_kT * b_tau * b_tau
673
673
674
674
# Create the initial state
675
675
return NPTLangevinState (
@@ -735,8 +735,8 @@ def npt_update(
735
735
736
736
# Update barostat mass based on current temperature
737
737
# This ensures proper coupling between system and barostat
738
- n_atoms_per_batch = torch .bincount (state .batch )
739
- state .cell_masses = (n_atoms_per_batch + 1 ) * batch_kT * b_tau * b_tau
738
+ n_atoms_per_graph = torch .bincount (state .batch )
739
+ state .cell_masses = (n_atoms_per_graph + 1 ) * batch_kT * b_tau * b_tau
740
740
741
741
# Compute model output for current state
742
742
model_output = model (state )
@@ -1017,8 +1017,8 @@ def update_cell_mass(
1017
1017
kT_batch = kT .expand (state .n_batches ) if kT .ndim == 0 else kT
1018
1018
1019
1019
# Calculate cell masses for each batch
1020
- n_atoms_per_batch = torch .bincount (state .batch , minlength = state .n_batches )
1021
- cell_mass = dim * (n_atoms_per_batch + 1 ) * kT_batch * state .barostat .tau ** 2
1020
+ n_atoms_per_graph = torch .bincount (state .batch , minlength = state .n_batches )
1021
+ cell_mass = dim * (n_atoms_per_graph + 1 ) * kT_batch * state .barostat .tau ** 2
1022
1022
1023
1023
# Update state with new cell masses
1024
1024
state .cell_mass = cell_mass .to (device = device , dtype = dtype )
@@ -1213,15 +1213,15 @@ def compute_cell_force(
1213
1213
1214
1214
# Compute kinetic energy contribution per batch
1215
1215
# Split momenta and masses by batch
1216
- KE_per_batch = torch .zeros (
1216
+ KE_per_graph = torch .zeros (
1217
1217
n_batches , device = positions .device , dtype = positions .dtype
1218
1218
)
1219
1219
for b in range (n_batches ):
1220
1220
batch_mask = batch == b
1221
1221
if batch_mask .any ():
1222
1222
batch_momenta = momenta [batch_mask ]
1223
1223
batch_masses = masses [batch_mask ]
1224
- KE_per_batch [b ] = calc_kinetic_energy (batch_momenta , batch_masses )
1224
+ KE_per_graph [b ] = calc_kinetic_energy (batch_momenta , batch_masses )
1225
1225
1226
1226
# Get stress tensor and compute trace per batch
1227
1227
# Handle stress tensor with batch dimension
@@ -1236,7 +1236,7 @@ def compute_cell_force(
1236
1236
# Compute force on cell coordinate per batch
1237
1237
# F = alpha * KE - dU/dV - P*V*d
1238
1238
return (
1239
- (alpha * KE_per_batch )
1239
+ (alpha * KE_per_graph )
1240
1240
- (internal_pressure * volume )
1241
1241
- (external_pressure * volume * dim )
1242
1242
)
@@ -1285,8 +1285,8 @@ def npt_inner_step(
1285
1285
model_output = model (state )
1286
1286
1287
1287
# First half step: Update momenta
1288
- n_atoms_per_batch = torch .bincount (state .batch , minlength = state .n_batches )
1289
- alpha = 1 + 1 / n_atoms_per_batch # [n_batches]
1288
+ n_atoms_per_graph = torch .bincount (state .batch , minlength = state .n_batches )
1289
+ alpha = 1 + 1 / n_atoms_per_graph # [n_batches]
1290
1290
1291
1291
cell_force_val = compute_cell_force (
1292
1292
alpha = alpha ,
@@ -1425,8 +1425,8 @@ def npt_nose_hoover_init(
1425
1425
kT_batch = kT .expand (n_batches ) if kT .ndim == 0 else kT
1426
1426
1427
1427
# Calculate cell masses for each batch
1428
- n_atoms_per_batch = torch .bincount (state .batch , minlength = n_batches )
1429
- cell_mass = dim * (n_atoms_per_batch + 1 ) * kT_batch * b_tau ** 2
1428
+ n_atoms_per_graph = torch .bincount (state .batch , minlength = n_batches )
1429
+ cell_mass = dim * (n_atoms_per_graph + 1 ) * kT_batch * b_tau ** 2
1430
1430
cell_mass = cell_mass .to (device = device , dtype = dtype )
1431
1431
1432
1432
# Calculate cell kinetic energy (using first batch for initialization)
@@ -1596,19 +1596,19 @@ def npt_nose_hoover_invariant(
1596
1596
e_pot = state .energy # Should be scalar or [n_batches]
1597
1597
1598
1598
# Calculate kinetic energy of particles per batch
1599
- e_kin_per_batch = calc_kinetic_energy (state .momenta , state .masses , batch = state .batch )
1599
+ e_kin_per_graph = calc_kinetic_energy (state .momenta , state .masses , batch = state .batch )
1600
1600
1601
1601
# Calculate degrees of freedom per batch
1602
- n_atoms_per_batch = torch .bincount (state .batch )
1603
- DOF_per_batch = (
1604
- n_atoms_per_batch * state .positions .shape [- 1 ]
1602
+ n_atoms_per_graph = torch .bincount (state .batch )
1603
+ DOF_per_graph = (
1604
+ n_atoms_per_graph * state .positions .shape [- 1 ]
1605
1605
) # n_atoms * n_dimensions
1606
1606
1607
1607
# Initialize total energy with PE + KE
1608
1608
if isinstance (e_pot , torch .Tensor ) and e_pot .ndim > 0 :
1609
- e_tot = e_pot + e_kin_per_batch # [n_batches]
1609
+ e_tot = e_pot + e_kin_per_graph # [n_batches]
1610
1610
else :
1611
- e_tot = e_pot + e_kin_per_batch # [n_batches]
1611
+ e_tot = e_pot + e_kin_per_graph # [n_batches]
1612
1612
1613
1613
# Add thermostat chain contributions
1614
1614
# Note: These are global thermostat variables, so we add them to each batch
@@ -1618,14 +1618,14 @@ def npt_nose_hoover_invariant(
1618
1618
2 * state .thermostat .masses [0 ]
1619
1619
)
1620
1620
1621
- # Ensure kT can broadcast properly with DOF_per_batch
1621
+ # Ensure kT can broadcast properly with DOF_per_graph
1622
1622
if isinstance (kT , torch .Tensor ) and kT .ndim == 0 :
1623
- # Scalar kT - expand to match DOF_per_batch shape
1624
- kT_expanded = kT .expand_as (DOF_per_batch )
1623
+ # Scalar kT - expand to match DOF_per_graph shape
1624
+ kT_expanded = kT .expand_as (DOF_per_graph )
1625
1625
else :
1626
1626
kT_expanded = kT
1627
1627
1628
- thermostat_energy += DOF_per_batch * kT_expanded * state .thermostat .positions [0 ]
1628
+ thermostat_energy += DOF_per_graph * kT_expanded * state .thermostat .positions [0 ]
1629
1629
1630
1630
# Add remaining thermostat terms
1631
1631
for pos , momentum , mass in zip (
0 commit comments