diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 81d90472..d2b6fcc9 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -368,12 +368,11 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float:
     )
 
     # Test with a small max_atoms value to limit the sequence
-    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=10)
-
+    max_size = determine_max_batch_size(si_sim_state, lj_model, max_atoms=16)
     # The Fibonacci sequence up to 10 is [1, 2, 3, 5, 8, 13]
     # Since we're not triggering OOM errors with our mock, it should
-    # return the largest value < max_atoms
-    assert max_size == 8
+    # return the largest value that fits within max_atoms (simstate has 8 atoms, so 2 batches)
+    assert max_size == 2
 
 
 @pytest.mark.parametrize("scale_factor", [1.1, 1.4])
@@ -395,7 +394,7 @@ def test_determine_max_batch_size_small_scale_factor_no_infinite_loop(
 
     # Verify sequence is strictly increasing (prevents infinite loop)
     sizes = [1]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < 20:
+    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1))*si_sim_state.n_atoms <= 20:
         sizes.append(next_size)
 
     assert all(sizes[idx] > sizes[idx - 1] for idx in range(1, len(sizes)))
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 4f556ce0..e0410afe 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -289,7 +289,9 @@ def determine_max_batch_size(
     """
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
-    while (next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)) < max_atoms:
+    while (
+        next_size := max(round(sizes[-1] * scale_factor), sizes[-1] + 1)
+    ) * state.n_atoms <= max_atoms:
         sizes.append(next_size)
 
     for i in range(len(sizes)):
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 5a11f86c..423f604a 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -426,6 +426,7 @@ def optimize(  # noqa: C901
         model,
         max_memory_scaler=autobatcher.max_memory_scaler,
         memory_scales_with=autobatcher.memory_scales_with,
+        max_atoms_to_try=autobatcher.max_atoms_to_try,
     )
     autobatcher.load_states(state)
     trajectory_reporter = _configure_reporter(