Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ def get_agent_pos(agent, space):
elif isinstance(space, Network):
agent_x, agent_y = agent.cell.coordinate, agent.cell.coordinate
else:
agent_x = agent.pos[0] if agent.pos else agent.cell.coordinate[0]
agent_y = agent.pos[1] if agent.pos else agent.cell.coordinate[1]
agent_x = (
agent.pos[0] if agent.pos is not None else agent.cell.coordinate[0]
)
agent_y = (
agent.pos[1] if agent.pos is not None else agent.cell.coordinate[1]
)
return agent_x, agent_y

arguments = {
Expand Down Expand Up @@ -104,14 +108,12 @@ def get_agent_pos(agent, space):
agent_x, agent_y = get_agent_pos(agent, space)

# Extract values from the dict, using defaults if not provided
size_val = dict_data.pop("s", style_fields.get("size"))
color_val = dict_data.pop("c", style_fields.get("color"))
size_val = dict_data.pop("size", style_fields.get("size"))
color_val = dict_data.pop("color", style_fields.get("color"))
marker_val = dict_data.pop("marker", style_fields.get("marker"))
zorder_val = dict_data.pop("zorder", style_fields.get("zorder"))
alpha_val = dict_data.pop("alpha", style_fields.get("alpha"))
edgecolors_val = dict_data.pop(
"edgecolors", color_val
) # default to agent's color if not provided
edgecolors_val = dict_data.pop("edgecolors", None)
linewidths_val = dict_data.pop("linewidths", style_fields.get("linewidths"))

aps = AgentPortrayalStyle(
Expand Down Expand Up @@ -158,7 +160,8 @@ def get_agent_pos(agent, space):
arguments["marker"].append(aps.marker)
arguments["zorder"].append(aps.zorder)
arguments["alpha"].append(aps.alpha)
arguments["edgecolors"].append(aps.edgecolors)
if aps.edgecolors is not None:
arguments["edgecolors"].append(aps.edgecolors)
arguments["linewidths"].append(aps.linewidths)

data = {
Expand Down Expand Up @@ -756,6 +759,10 @@ def _scatter(ax: Axes, arguments, **kwargs):
zorder_mask = z_order == zorder
logical = mark_mask & zorder_mask

# No agents with this marker and z-order, skip
if not np.any(logical):
continue

ax.scatter(
x[logical],
y[logical],
Expand Down