
Commit 89835fd

niklashoelter, orionarcher, and CompRhys authored
Fixed max atoms memory estimation (#279)
Co-authored-by: Orion Cohen <[email protected]>
Co-authored-by: Rhys Goodall <[email protected]>
1 parent e4c5608 commit 89835fd

File tree

3 files changed: +11 -7 lines changed


tests/test_autobatching.py

Lines changed: 7 additions & 6 deletions
@@ -361,12 +361,11 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     )
 
     # Test with a small max_atoms value to limit the sequence
-    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10)
-
+    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=16)
     # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13]
-    # Since we're not triggering OOM errors with our mock, it should
-    # return the largest value < max_atoms
-    assert max_size == 8
+    # Since we're not triggering OOM errors with our mock, it should return the
+    # largest value that fits within max_atoms (simstate has 8 atoms, so 2 batches)
+    assert max_size == 2
 
 
 @pytest.mark.parametrize("scale_factor", [1.1, 1.4])
@@ -388,7 +387,9 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(
 
     # Verify sequence is strictly increasing (prevents infinite loop)
     sizes = [1]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * si_sim_state.n_atoms <= 20:
         sizes.append(next_size)
 
     assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))
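
The updated assertion follows directly from the corrected stopping condition. A minimal re-derivation, assuming `start_size=1` and `scale_factor=1.6` as defaults (values inferred from the Fibonacci-like comment, not confirmed from the source):

```python
# Hypothetical re-derivation of the test's expected value; start_size=1
# and scale_factor=1.6 are assumed defaults, not taken from the library.
n_atoms, max_atoms = 8, 16  # si_sim_state has 8 atoms; the test passes max_atoms=16
sizes = [1]
while (
    next_size := max(round(sizes[-1] * 1.6), sizes[-1] + 1)
) * n_atoms <= max_atoms:
    sizes.append(next_size)
assert sizes == [1, 2]  # 2 * 8 = 16 atoms fits the cap; 3 * 8 = 24 does not
# With the mocked measurement never raising OOM, determine_max_batch_size
# returns the largest candidate, 2 -- hence `assert max_size == 2`.
```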

torch_sim/autobatching.py

Lines changed: 3 additions & 1 deletion
@@ -287,7 +287,9 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * state.n_atoms <= max_atoms:
         sizes.append(next_size)
 
     for sys_idx in range(len(sizes)):
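
This is the heart of the fix: the old condition compared `next_size`, a count of systems per batch, directly against `max_atoms`, a budget in atoms, so any state with more than one atom per system could grow the candidate list far past the intended cap. Multiplying by `state.n_atoms` converts the batch count into an atom count before the comparison. A sketch of the difference, with `scale_factor=1.6` assumed and `n_atoms=8` mirroring the test state:

```python
# Illustrative only: reproduces the old and new loop conditions side by
# side; the scale factor and atom counts are assumptions, not library code.
def old_candidates(max_atoms: int) -> list[int]:
    sizes = [1]
    # Old check: a batch *count* is compared to an *atom* budget.
    while (next_size := max(round(sizes[-1] * 1.6), sizes[-1] + 1)) < max_atoms:
        sizes.append(next_size)
    return sizes

def new_candidates(max_atoms: int, n_atoms: int) -> list[int]:
    sizes = [1]
    # Fixed check: next_size systems hold next_size * n_atoms atoms.
    while (
        next_size := max(round(sizes[-1] * 1.6), sizes[-1] + 1)
    ) * n_atoms <= max_atoms:
        sizes.append(next_size)
    return sizes

print(old_candidates(100))             # [1, 2, 3, 5, 8, 13, 21, 34, 54, 86]
print(new_candidates(100, n_atoms=8))  # [1, 2, 3, 5, 8] -- at most 64 atoms per batch
```

With the old condition, a batch of 86 eight-atom systems (688 atoms) would have been attempted under a 100-atom budget; the corrected loop stops at 8 systems (64 atoms).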

torch_sim/runners.py

Lines changed: 1 addition & 0 deletions
@@ -449,6 +449,7 @@ def optimize[T: OptimState](  # noqa: C901, PLR0915
         init_kwargs=dict(**init_kwargs or {}),
         max_memory_scaler=autobatcher.max_memory_scaler,
         memory_scales_with=autobatcher.memory_scales_with,
+        max_atoms_to_try=autobatcher.max_atoms_to_try,
     )
     autobatcher.load_states(state)
     if trajectory_reporter is not None:
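
The surrounding code copies the caller's autobatcher settings onto an internally constructed batcher; before this commit, `max_atoms_to_try` was the one setting not carried over, so a user-specified cap was silently ignored. A hypothetical sketch of the forwarding pattern (class name, fields, and the default value are illustrative, not the library's API):

```python
from dataclasses import dataclass

@dataclass
class Batcher:  # stand-in for the library's autobatcher, not its real API
    max_memory_scaler: float | None = None
    memory_scales_with: str = "n_atoms"
    max_atoms_to_try: int = 500_000  # illustrative default

def rebuild(user_batcher: Batcher) -> Batcher:
    # Mirrors the fixed code path in `optimize`: every user-tunable field
    # is copied onto the internally rebuilt batcher, including the
    # previously dropped `max_atoms_to_try`.
    return Batcher(
        max_memory_scaler=user_batcher.max_memory_scaler,
        memory_scales_with=user_batcher.memory_scales_with,
        max_atoms_to_try=user_batcher.max_atoms_to_try,
    )

# A caller's custom cap now survives the rebuild.
assert rebuild(Batcher(max_atoms_to_try=10_000)).max_atoms_to_try == 10_000
```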
