Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions mesa/batchrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,32 @@

import itertools
import multiprocessing
from collections.abc import Iterable, Mapping
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from multiprocessing import Pool
from typing import Any

import numpy as np
from tqdm.auto import tqdm

from mesa.model import Model

multiprocessing.set_start_method("spawn", force=True)

SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence


def batch_run(
model_cls: type[Model],
parameters: Mapping[str, Any | Iterable[Any]],
# We still retain the Optional[int] because users may set it to None (i.e. use all CPUs)
number_processes: int | None = 1,
iterations: int = 1,
iterations: int | None = None,
data_collection_period: int = -1,
max_steps: int = 1000,
display_progress: bool = True,
rng: SeedLike | Iterable[SeedLike] | None = None,
) -> list[dict[str, Any]]:
"""Batch run a mesa model with a set of parameter values.

Expand All @@ -62,6 +67,7 @@ def batch_run(
data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode)
max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000
display_progress (bool, optional): Display batch run process, by default True
rng : a valid value or iterable of values for seeding the random number generator in the model

Returns:
List[Dict[str, Any]]
Expand All @@ -70,11 +76,28 @@ def batch_run(
batch_run assumes the model has a `datacollector` attribute that has a DataCollector object initialized.

"""
if iterations is not None and rng is not None:
raise ValueError(
"you cannot use both iterations and rng at the same time. Please only use rng."
)
if iterations is not None:
warnings.warn(
"iterations is deprecated, please use rng instead",
DeprecationWarning,
stacklevel=2,
)
rng = [
None,
] * iterations
if not isinstance(rng, Iterable):
rng = [rng]

runs_list = []
run_id = 0
for iteration in range(iterations):
for i, rng_i in enumerate(rng):
for kwargs in _make_model_kwargs(parameters):
runs_list.append((run_id, iteration, kwargs))
kwargs["rng"] = rng_i
runs_list.append((run_id, i, kwargs))
run_id += 1

process_func = partial(
Expand Down Expand Up @@ -170,6 +193,7 @@ def _model_run_func(
Return model_data, agent_data from the reporters
"""
run_id, iteration, kwargs = run

model = model_cls(**kwargs)
while model.running and model.steps <= max_steps:
model.step()
Expand Down
110 changes: 109 additions & 1 deletion tests/test_batch_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test Batchrunner."""

import pytest

import mesa
from mesa.agent import Agent
from mesa.batchrunner import _make_model_kwargs
Expand Down Expand Up @@ -130,7 +132,7 @@ def step(self): # noqa: D102


def test_batch_run(): # noqa: D103
result = mesa.batch_run(MockModel, {}, number_processes=2)
result = mesa.batch_run(MockModel, {}, number_processes=2, rng=42)
assert result == [
{
"RunId": 0,
Expand All @@ -140,6 +142,7 @@ def test_batch_run(): # noqa: D103
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
Expand All @@ -149,6 +152,7 @@ def test_batch_run(): # noqa: D103
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
Expand All @@ -158,9 +162,111 @@ def test_batch_run(): # noqa: D103
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 42,
},
]

result = mesa.batch_run(MockModel, {}, number_processes=2, iterations=1)
assert result == [
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": None,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": None,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": None,
},
]

result = mesa.batch_run(MockModel, {}, number_processes=2, rng=[42, 31415])
assert result == [
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 0,
"iteration": 0,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 42,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 1,
"agent_id": 1,
"agent_local": 250.0,
"rng": 31415,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 2,
"agent_id": 2,
"agent_local": 250.0,
"rng": 31415,
},
{
"RunId": 1,
"iteration": 1,
"Step": 1000,
"reported_model_param": 42,
"AgentID": 3,
"agent_id": 3,
"agent_local": 250.0,
"rng": 31415,
},
]

with pytest.raises(ValueError):
mesa.batch_run(MockModel, {}, number_processes=2, rng=42, iterations=1)


def test_batch_run_with_params(): # noqa: D103
mesa.batch_run(
Expand All @@ -185,6 +291,7 @@ def test_batch_run_no_agent_reporters(): # noqa: D103
"Step": 1000,
"enable_agent_reporters": False,
"reported_model_param": 42,
"rng": None,
}
]

Expand All @@ -208,6 +315,7 @@ def test_batch_run_unhashable_param(): # noqa: D103
"agent_local": 250.0,
"n_agents": 2,
"variable_model_params": {"key": "value"},
"rng": None,
}

assert result == [
Expand Down
Loading