From 7243dc9f4a6f38a63d48f8949fb5ad1da79b6cf9 Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Thu, 28 Mar 2024 15:04:52 +0600 Subject: [PATCH 1/4] add `group_by` method to `AgentSet` for attribute-based grouping Signed-off-by: Naymul Islam --- mesa/agent.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mesa/agent.py b/mesa/agent.py index 7ae76871b95..d7debb1006e 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -345,6 +345,22 @@ def __setstate__(self, state): self.model = state["model"] self._update(state["agents"]) + def group_by(self, attr_name: str) -> dict: + """ + Group agents in the AgentSet based on the specified attribute. + + Args: + attr_name (str): The name of the attribute to group agents by. + + Returns: + dict: A dictionary where keys are attribute values and values are lists of agents with those attribute values. + """ + grouped_agents = defaultdict(list) + for agent in self: + attr_value = getattr(agent, attr_name, None) + grouped_agents[attr_value].append(agent) + return grouped_agents + @property def random(self) -> Random: """ From 7478489d5c8f454cffb9a0284f5b51bd305140e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 09:10:53 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesa/agent.py b/mesa/agent.py index d7debb1006e..170f8cac68a 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -360,7 +360,7 @@ def group_by(self, attr_name: str) -> dict: attr_value = getattr(agent, attr_name, None) grouped_agents[attr_value].append(agent) return grouped_agents - + @property def random(self) -> Random: """ From 322bcda45919b5dabd787988502d06456dc7b87a Mon Sep 17 00:00:00 2001 From: Naymul Islam Date: Thu, 28 Mar 2024 18:06:43 +0600 Subject: [PATCH 3/4] add test for agentset-group-by Signed-off-by: Naymul Islam --- tests/test_agent.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_agent.py b/tests/test_agent.py index 0cd211123e1..4bb58d3e59c 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -282,3 +282,16 @@ def test_agentset_shuffle(): agentset = AgentSet(test_agents, model=model) agentset.shuffle(inplace=True) assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset)) + +def test_agentset_group_by(): + model = Model() + agents = [TestAgent(model.next_id(), model) for _ in range(10)] + for i, agent in enumerate(agents): + agent.category = i % 2 # Assign categories 0 or 1 + agentset = AgentSet(agents, model) + + grouped = agentset.group_by("category") + assert len(grouped[0]) == 5 + assert len(grouped[1]) == 5 + assert all(agent.category == 0 for agent in grouped[0]) + assert all(agent.category == 1 for agent in grouped[1]) \ No newline at end of file From d5fc9b9e8a8a46697123ab6708b75a73d168bd01 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 12:08:11 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index 4bb58d3e59c..7b90398e3b0 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -283,6 +283,7 @@ def test_agentset_shuffle(): agentset.shuffle(inplace=True) assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset)) + def test_agentset_group_by(): model = Model() agents = [TestAgent(model.next_id(), model) for _ in range(10)] @@ -294,4 +295,4 @@ def test_agentset_group_by(): assert len(grouped[0]) == 5 assert len(grouped[1]) == 5 assert all(agent.category == 0 for agent in grouped[0]) - assert all(agent.category == 1 for agent in grouped[1]) \ No newline at end of file + assert all(agent.category == 1 for agent in grouped[1])