
Commit 2bb3bd5

t-reents and janosh authored
Fix memory scaling in determine_max_batch_size (#212)
* Fix memory scaling in `determine_max_batch_size`: the current version results in an infinite loop when `scale_factor < 1.5` due to rounding. This is fixed by increasing the batch size by at least `+1`.
* Add a check in `test_autobatching.py` to ensure `determine_max_batch_size` does not regress into an infinite loop.
* Remove the outdated pymatviz extra 'export-figs' in `6.1_Phonons_MACE.py` and `6.2_QuasiHarmonic_MACE.py`.
* Pin `plotly!=6.2.0`.

Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 3d2bea0 · commit 2bb3bd5
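
Why the old scaling rule hangs: Python's round() keeps small batch sizes fixed whenever scale_factor < 1.5 (e.g. round(1 * 1.1) == 1), so the size sequence never grows. Below is a minimal standalone sketch of the stall and of the +1 fix adopted in this commit; it uses only the standard library and plain integers in place of real batch sizes.

# Old rule: next_size = round(size * scale_factor) -- stalls at size 1.
scale_factor = 1.1
assert round(1 * scale_factor) == 1  # 1 -> 1 -> 1 ... forever

# Fixed rule: always grow by at least +1, so the loop must terminate.
sizes = [1]
while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 10:
    sizes.append(next_size)
print(sizes)  # [1, 2, 3, 4, 5, 6, 7, 8, 9] -- strictly increasing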

File tree

5 files changed: +33 -4 lines changed


docs/_static/draw_pkg_treemap.py

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 # /// script
 # dependencies = [
 #     "pymatviz @ git+https://github.com/janosh/pymatviz",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///
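
The `# /// script` blocks touched in this commit are PEP 723 inline script metadata: the dependency list travels with the script itself, so the `plotly!=6.2.0` exclusion applies whenever the file is executed by a metadata-aware tool. A minimal sketch of the same pin pattern, assuming `uv run` as the runner (the diff does not name one) and a hypothetical file name:

# demo_plotly_pin.py -- hypothetical example; run with: uv run demo_plotly_pin.py
# /// script
# dependencies = [
#     "plotly!=6.2.0",  # exclusion pin: any version except 6.2.0
# ]
# ///
import plotly

print(plotly.__version__)  # the resolver will never pick 6.2.0

The same pin recurs in 6.1_Phonons_MACE.py and 6.2_QuasiHarmonic_MACE.py below, since each script carries its own metadata block.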

examples/scripts/6_Phonons/6.1_Phonons_MACE.py

Lines changed: 2 additions & 1 deletion

@@ -4,9 +4,10 @@
 # dependencies = [
 #     "mace-torch>=0.3.12",
 #     "phonopy>=2.35",
-#     "pymatviz[export-figs]>=0.15.1",
+#     "pymatviz>=0.16",
 #     "seekpath",
 #     "ase",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///

examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py

Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,8 @@
 # dependencies = [
 #     "mace-torch>=0.3.12",
 #     "phonopy>=2.35",
-#     "pymatviz[export-figs]>=0.15.1",
+#     "pymatviz>=0.16",
+#     "plotly!=6.2.0",  # TODO remove pin pending https://github.com/plotly/plotly.py/issues/5253#issuecomment-3016615635
 # ]
 # ///

tests/test_autobatching.py

Lines changed: 26 additions & 0 deletions

@@ -376,6 +376,32 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     assert max_size == 8
 
 
+@pytest.mark.parametrize("scale_factor", [1.1, 1.4])
+def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(
+    si_sim_state: ts.SimState,
+    lj_model: LennardJonesModel,
+    monkeypatch: pytest.MonkeyPatch,
+    scale_factor: float,
+) -> None:
+    """Test determine_max_batch_size doesn't infinite loop with small scale factors."""
+    monkeypatch.setattr(
+        "torch_sim.autobatching.measure_model_memory_forward", lambda *_: 0.1
+    )
+
+    max_size = determine_max_batch_size(
+        si_sim_state, lj_model, max_atoms=20, scale_factor=scale_factor
+    )
+    assert 0 < max_size <= 20
+
+    # Verify sequence is strictly increasing (prevents infinite loop)
+    sizes = [1]
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+        sizes.append(next_size)
+
+    assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))
+    assert max_size == sizes[-1]
+
+
 def test_in_flight_auto_batcher_restore_order(
     si_sim_state: ts.SimState,
     fe_supercell_sim_state: ts.SimState,
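
The new test never touches a GPU: monkeypatch.setattr replaces torch_sim.autobatching.measure_model_memory_forward (addressed by its dotted string path) with a constant, so determine_max_batch_size always believes memory is available and must terminate on its own. A tiny self-contained sketch of that stubbing pattern, using json.loads as a stand-in target:

import json

import pytest


def test_stub_by_dotted_path(monkeypatch: pytest.MonkeyPatch) -> None:
    # Replace a module attribute by its dotted path; pytest restores
    # the original automatically once the test finishes.
    monkeypatch.setattr("json.loads", lambda *_args, **_kwargs: {"stubbed": True})
    assert json.loads("{}") == {"stubbed": True}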

torch_sim/autobatching.py

Lines changed: 2 additions & 2 deletions

@@ -268,7 +268,7 @@ def determine_max_batch_size(
             Defaults to 500,000.
         start_size (int): Initial batch size to test. Defaults to 1.
         scale_factor (float): Factor to multiply batch size by in each iteration.
-            Defaults to 1.3.
+            Defaults to 1.6.
 
     Returns:
         int: Maximum number of batches that fit in GPU memory.

@@ -289,7 +289,7 @@
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := round(sizes[-1] * scale_factor)) < max_atoms:
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
         sizes.append(next_size)
 
     for i in range(len(sizes)):
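
Seen side by side, the old update rule produces a constant sequence for scale_factor < 1.5, while the patched rule is strictly increasing by construction. A standalone sketch (batch_sizes is a hypothetical helper that mirrors the loop above, not torch-sim API):

def batch_sizes(scale_factor: float, max_atoms: int, fixed: bool) -> list[int]:
    """Rebuild the size schedule of determine_max_batch_size in isolation."""
    sizes = [1]
    for _ in range(max_atoms):  # bounded so the buggy rule cannot hang here
        next_size = round(sizes[-1] * scale_factor)
        if fixed:
            next_size = max(next_size, sizes[-1] + 1)  # the commit's fix
        if next_size >= max_atoms:
            break
        sizes.append(next_size)
    return sizes

print(batch_sizes(1.4, 20, fixed=False))  # stalls: [1, 1, 1, ...]
print(batch_sizes(1.4, 20, fixed=True))   # grows: [1, 2, 3, 4, 6, 8, 11, 15]

The docstring default also moves from 1.3 to 1.6, which sits safely above the 1.5 threshold below which plain rounding can stall.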
