From 2980cbad5e4b78f4ded17f91533a36e8d8451bc5 Mon Sep 17 00:00:00 2001
From: niklashoelter
Date: Wed, 8 Oct 2025 10:15:54 +0200
Subject: [PATCH 1/4] Fixed mem estimation loop to compare number of atoms

---
 torch_sim/autobatching.py | 2 +-
 torch_sim/runners.py      | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 4f556ce0..3d19a30e 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -289,7 +289,7 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1))*state.n_atoms <= max_atoms:
         sizes.append(next_size)
 
     for i in range(len(sizes)):
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 5a11f86c..423f604a 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -426,6 +426,7 @@ def optimize(  # noqa: C901
             model,
             max_memory_scaler=autobatcher.max_memory_scaler,
             memory_scales_with=autobatcher.memory_scales_with,
+            max_atoms_to_try=autobatcher.max_atoms_to_try,
         )
         autobatcher.load_states(state)
         trajectory_reporter = _configure_reporter(

From 0175ff8975ff98a77b15e77598e70f6ee4ce7701 Mon Sep 17 00:00:00 2001
From: niklashoelter
Date: Wed, 8 Oct 2025 10:36:00 +0200
Subject: [PATCH 2/4] adapted tests for fix

---
 tests/test_autobatching.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 81d90472..a66cc8c0 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -368,12 +368,13 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     )
 
     # Test with a small max_atoms value to limit the sequence
-    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10)
-
+    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=16)
+    print(si_sim_state.n_atoms)
+    print(max_size)
     # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13]
     # Since we're not triggering OOM errors with our mock, it should
-    # return the largest value < max_atoms
-    assert max_size == 8
+    # return the largest value that fits within max_atoms (simstate has 8 atoms, so 2 batches)
+    assert max_size == 2
 
 
 @pytest.mark.parametrize("scale_factor", [1.1, 1.4])
@@ -395,7 +396,7 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(
 
     # Verify sequence is strictly increasing (prevents infinite loop)
     sizes = [1]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1))*si_sim_state.n_atoms <= 20:
         sizes.append(next_size)
 
     assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))

From 2f30358e9e8ba49f8d01176a357a92dd0a2e9c3a Mon Sep 17 00:00:00 2001
From: niklashoelter
Date: Wed, 8 Oct 2025 10:45:49 +0200
Subject: [PATCH 3/4] adapted tests

---
 tests/test_autobatching.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index a66cc8c0..d2b6fcc9 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -369,8 +369,6 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
 
     # Test with a small max_atoms value to limit the sequence
     max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=16)
-    print(si_sim_state.n_atoms)
-    print(max_size)
     # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13]
     # Since we're not triggering OOM errors with our mock, it should
     # return the largest value that fits within max_atoms (simstate has 8 atoms, so 2 batches)

From b3d25661bf1ec4dab62ba99b91d198fd50db4759 Mon Sep 17 00:00:00 2001
From: niklashoelter
Date: Thu, 9 Oct 2025 09:28:57 +0200
Subject: [PATCH 4/4] run pre-commit

---
 torch_sim/autobatching.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 3d19a30e..e0410afe 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -289,7 +289,9 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
    sizes = [start_size]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1))*state.n_atoms <= max_atoms:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * state.n_atoms <= max_atoms:
         sizes.append(next_size)
 
     for i in range(len(sizes)):
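
Note for reviewers: below is a self-contained sketch of the loop this series fixes, runnable outside torch-sim. The helper name batch_size_sequence and the scale_factor default of 1.6 are illustrative assumptions (1.6 happens to reproduce the Fibonacci-like sizes [1, 2, 3, 5, 8, 13] that the test comment mentions); only the while-condition mirrors the patched code, which now caps the product of batch size and atoms per state rather than the batch size alone.

def batch_size_sequence(
    n_atoms_per_state: int,
    max_atoms: int,
    start_size: int = 1,
    scale_factor: float = 1.6,  # assumed default, chosen to match the test comment
) -> list[int]:
    """Geometric sequence of candidate batch sizes, capped by total atom count."""
    sizes = [start_size]
    # A batch of next_size states holds next_size * n_atoms_per_state atoms;
    # the fix compares that product (not next_size alone) against max_atoms.
    while (
        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
    ) * n_atoms_per_state <= max_atoms:
        sizes.append(next_size)
    return sizes


# Mirrors the updated test: an 8-atom state with max_atoms=16 allows at most
# 2 states per batch (2 * 8 = 16 atoms), so the sequence is [1, 2].
assert batch_size_sequence(n_atoms_per_state=8, max_atoms=16) == [1, 2]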