Skip to content

Commit 39c2ad2

Browse files
committed
git commit -m "fix: Remove cell dimension squeezing in NVT Nose-Hoover integrator" -m "- Remove problematic cell.squeeze(0) that breaks batching support" -m "- Fix calculate_momenta function call to use correct signature with batch parameter" -m "- Resolves RuntimeError when using MACE with NVT Nose-Hoover thermostat" -m "" -m "Fixes example script 3.5_MACE_NVT_Nose_Hoover.py which was failing due to" -m "neighbor list function receiving wrong tensor shapes when cell batch" -m "dimension was incorrectly removed."
1 parent 8c2543b commit 39c2ad2

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

torch_sim/integrators/nvt.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ def nvt_nose_hoover(
324324
4. Update chain kinetic energy
325325
5. Second half-step of chain evolution
326326
"""
327-
device, dtype = model.device, model.dtype
328327

329328
def nvt_nose_hoover_init(
330329
state: SimState | StateDict,
@@ -358,16 +357,12 @@ def nvt_nose_hoover_init(
358357
if not isinstance(state, SimState):
359358
state = SimState(**state)
360359

361-
# Check if there is an extra batch dimension
362-
if state.cell.dim() == 3:
363-
state.cell = state.cell.squeeze(0)
364-
365360
atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers)
366361

367362
model_output = model(state)
368363
momenta = kwargs.get(
369364
"momenta",
370-
calculate_momenta(state.positions, state.masses, kT, device, dtype, seed),
365+
calculate_momenta(state.positions, state.masses, state.batch, kT, seed),
371366
)
372367

373368
# Calculate initial kinetic energy

0 commit comments

Comments
 (0)