diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py index 537336ef5ee..dfa00327bb4 100644 --- a/mesa/examples/advanced/sugarscape_g1mt/app.py +++ b/mesa/examples/advanced/sugarscape_g1mt/app.py @@ -1,16 +1,26 @@ from mesa.examples.advanced.sugarscape_g1mt.model import SugarscapeG1mt from mesa.visualization import Slider, SolaraViz, make_plot_component +from mesa.visualization.components import AgentPortrayalStyle, PropertyLayerStyle from mesa.visualization.components.matplotlib_components import make_mpl_space_component def agent_portrayal(agent): - return {"marker": "o", "color": "red", "size": 10} + return AgentPortrayalStyle( + x=agent.cell.coordinate[0], + y=agent.cell.coordinate[1], + color="red", + marker="o", + size=10, + zorder=1, + ) -propertylayer_portrayal = { - "sugar": {"color": "blue", "alpha": 0.8, "colorbar": True, "vmin": 0, "vmax": 10}, - "spice": {"color": "red", "alpha": 0.8, "colorbar": True, "vmin": 0, "vmax": 10}, -} +def propertylayer_portrayal(layer): + if layer.name == "sugar": + return PropertyLayerStyle( + color="blue", alpha=0.8, colorbar=True, vmin=0, vmax=10 + ) + return PropertyLayerStyle(color="red", alpha=0.8, colorbar=True, vmin=0, vmax=10) sugarscape_space = make_mpl_space_component( diff --git a/mesa/examples/advanced/sugarscape_g1mt/tests.py b/mesa/examples/advanced/sugarscape_g1mt/tests.py deleted file mode 100644 index a25715ea0d5..00000000000 --- a/mesa/examples/advanced/sugarscape_g1mt/tests.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -from scipy import stats - -from .agents import Trader -from .model import SugarscapeG1mt, flatten - - -def check_slope(y, increasing): - x = range(len(y)) - slope, intercept, _, p_value, _ = stats.linregress(x, y) - result = (slope > 0) if increasing else (slope < 0) - # p_value for significance. - assert result and p_value < 0.05, (slope, p_value) - - -def test_decreasing_price_variance(): - # The variance of the average trade price should decrease over time (figure IV-3) - # See Growing Artificial Societies p. 109. - model = SugarscapeG1mt(42) - model.datacollector._new_model_reporter( - "price_variance", - lambda m: np.var( - flatten([a.prices for a in m.agents_by_type[Trader].values()]) - ), - ) - model.run_model(step_count=50) - - df_model = model.datacollector.get_model_vars_dataframe() - - check_slope(df_model.price_variance, increasing=False) - - -def test_carrying_capacity(): - def calculate_carrying_capacities(enable_trade): - carrying_capacities = [] - visions = range(1, 10) - for vision_max in visions: - model = SugarscapeG1mt(vision_max=vision_max, enable_trade=enable_trade) - model.run_model(step_count=50) - carrying_capacities.append(len(model.agents_by_type[Trader])) - return carrying_capacities - - # Carrying capacity should increase over mean vision (figure IV-6). - # See Growing Artificial Societies p. 112. - carrying_capacities_with_trade = calculate_carrying_capacities(True) - check_slope( - carrying_capacities_with_trade, - increasing=True, - ) - # Carrying capacity should be higher when trade is enabled (figure IV-6). - carrying_capacities_no_trade = calculate_carrying_capacities(False) - check_slope( - carrying_capacities_no_trade, - increasing=True, - ) - - t_statistic, p_value = stats.ttest_rel( - carrying_capacities_with_trade, carrying_capacities_no_trade - ) - # t_statistic > 0 means carrying_capacities_with_trade has larger values - # than carrying_capacities_no_trade. - # p_value for significance. - assert t_statistic > 0 and p_value < 0.05 - - -# TODO: -# 1. Reproduce figure IV-12 that the log of average price should decrease over average agent age -# 2. Reproduce figure IV-13 that the gini coefficient on trade should decrease over mean vision, and should be higher with trade -# 3. a stricter test would be to ensure the amount of variance of the trade price matches figure IV-3 diff --git a/mesa/visualization/components/__init__.py b/mesa/visualization/components/__init__.py index 4b70fc2b97c..db21723c404 100644 --- a/mesa/visualization/components/__init__.py +++ b/mesa/visualization/components/__init__.py @@ -1,4 +1,4 @@ -"""custom solara components.""" +"""Custom visualization components.""" from __future__ import annotations @@ -10,6 +10,19 @@ make_mpl_plot_component, make_mpl_space_component, ) +from .portrayal_components import AgentPortrayalStyle, PropertyLayerStyle + +__all__ = [ + "AgentPortrayalStyle", + "PropertyLayerStyle", + "SpaceAltair", + "SpaceMatplotlib", + "make_altair_space", + "make_mpl_plot_component", + "make_mpl_space_component", + "make_plot_component", + "make_space_component", +] def make_space_component( diff --git a/mesa/visualization/components/portrayal_components.py b/mesa/visualization/components/portrayal_components.py new file mode 100644 index 00000000000..f72e1cd2638 --- /dev/null +++ b/mesa/visualization/components/portrayal_components.py @@ -0,0 +1,79 @@ +"""Portrayal Components Module. + +This module defines data structures for styling visual elements in Mesa agent-based model visualizations. +It provides user-facing classes to specify how agents and property layers should appear in the rendered space. + +Classes: +- AgentPortrayalStyle: Controls the appearance of individual agents (e.g., color, shape, size, etc.). +- PropertyLayerStyle: Controls the appearance of background property layers (e.g., color gradients or uniform fills). + +These components are designed to be passed into Mesa visualizations to customize and standardize how data is presented. +""" + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class AgentPortrayalStyle: + """Represents the visual styling options for an agent in a visualization. + + User facing component to control how agents are drawn. + Allows specifying properties like color, size, + marker shape, position, and other plot attributes. + """ + + x: float | None = None + y: float | None = None + color: str | tuple | None = "tab:blue" + marker: str | None = "o" + size: int | float | None = 50 + zorder: int | None = 1 + alpha: float | None = 1.0 + edgecolors: str | tuple | None = None + linewidths: float | int | None = 1.0 + + def update(self, *updates_fields: tuple[str, Any]): + """Updates attributes from variable (field_name, new_value) tuple arguments. + + Example: + >>> def agent_portrayal(agent): + >>> primary_style = AgentPortrayalStyle(color="blue", marker="^", size=10, x=agent.pos[0], y=agent.pos[1]) + >>> if agent.type == 1: + >>> primary_style.update(("color", "red"), ("size", 30)) + >>> return primary_style + """ + for field_to_change, field_to_change_to in updates_fields: + if hasattr(self, field_to_change): + setattr(self, field_to_change, field_to_change_to) + else: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{field_to_change}'" + ) + + +@dataclass +class PropertyLayerStyle: + """Represents the visual styling options for a property layer in a visualization. + + User facing component to control how property layers are drawn. + Allows specifying properties like colormap, single color, value limits, + and colorbar visibility. + + Note: You can specify either a 'colormap' (for varying data) or a single + 'color' (for a uniform layer appearance), but not both simultaneously. + """ + + colormap: str | None = None + color: str | None = None + alpha: float = 0.8 + colorbar: bool = True + vmin: float | None = None + vmax: float | None = None + + def __post_init__(self): + """Validate that color and colormap are not simultaneously specified.""" + if self.color is not None and self.colormap is not None: + raise ValueError("Specify either 'color' or 'colormap', not both.") + if self.color is None and self.colormap is None: + raise ValueError("Specify one of 'color' or 'colormap'") diff --git a/mesa/visualization/mpl_space_drawing.py b/mesa/visualization/mpl_space_drawing.py index 93eadd80079..ffd2144dfa5 100644 --- a/mesa/visualization/mpl_space_drawing.py +++ b/mesa/visualization/mpl_space_drawing.py @@ -6,10 +6,10 @@ """ -import contextlib import itertools import warnings from collections.abc import Callable +from dataclasses import fields from functools import lru_cache from itertools import pairwise from typing import Any @@ -35,7 +35,6 @@ HexSingleGrid, MultiGrid, NetworkGrid, - PropertyLayer, SingleGrid, ) @@ -47,59 +46,120 @@ def collect_agent_data( space: OrthogonalGrid | HexGrid | Network | ContinuousSpace | VoronoiGrid, agent_portrayal: Callable, - color="tab:blue", - size=25, - marker="o", - zorder: int = 1, -): + default_size: float | None = None, +) -> dict: """Collect the plotting data for all agents in the space. Args: space: The space containing the Agents. - agent_portrayal: A callable that is called with the agent and returns a dict - color: default color - size: default size - marker: default marker - zorder: default zorder - - agent_portrayal should return a dict, limited to size (size of marker), color (color of marker), zorder (z-order), - marker (marker style), alpha, linewidths, and edgecolors + agent_portrayal: A callable that is called with the agent and returns a AgentPortrayalStyle + default_size: default size + agent_portrayal should return a AgentPortrayalStyle, limited to size (size of marker), color (color of marker), zorder (z-order), + marker (marker style), alpha, linewidths, and edgecolors. """ + + def get_agent_pos(agent, space): + """Helper function to get the agent position depending on the grid type.""" + if isinstance(space, NetworkGrid): + agent_x, agent_y = agent.pos, agent.pos + elif isinstance(space, Network): + agent_x, agent_y = agent.cell.coordinate, agent.cell.coordinate + else: + agent_x = agent.pos[0] if agent.pos else agent.cell.coordinate[0] + agent_y = agent.pos[1] if agent.pos else agent.cell.coordinate[1] + return agent_x, agent_y + arguments = { + "loc": [], "s": [], "c": [], "marker": [], "zorder": [], - "loc": [], "alpha": [], "edgecolors": [], "linewidths": [], } + # Importing AgentPortrayalStyle inside the function to prevent circular imports + from mesa.visualization.components import AgentPortrayalStyle + + # Get AgentPortrayalStyle defaults + style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)} + class_default_size = style_fields.get("size") + for agent in space.agents: - portray = agent_portrayal(agent) - loc = agent.pos - if loc is None: - loc = agent.cell.coordinate - - arguments["loc"].append(loc) - arguments["s"].append(portray.pop("size", size)) - arguments["c"].append(portray.pop("color", color)) - arguments["marker"].append(portray.pop("marker", marker)) - arguments["zorder"].append(portray.pop("zorder", zorder)) - - for entry in ["alpha", "edgecolors", "linewidths"]: - with contextlib.suppress(KeyError): - arguments[entry].append(portray.pop(entry)) - - if len(portray) > 0: - ignored_fields = list(portray.keys()) - msg = ", ".join(ignored_fields) + portray_input = agent_portrayal(agent) + aps: AgentPortrayalStyle + + if isinstance(portray_input, dict): warnings.warn( - f"the following fields are not used in agent portrayal and thus ignored: {msg}.", + "Returning a dict from agent_portrayal is deprecated and will be removed " + "in a future version. Please return an AgentPortrayalStyle instance instead.", + DeprecationWarning, stacklevel=2, ) + dict_data = portray_input.copy() + + agent_x, agent_y = get_agent_pos(agent, space) + + # Extract values from the dict, using defaults if not provided + size_val = dict_data.pop("s", style_fields.get("size")) + color_val = dict_data.pop("c", style_fields.get("color")) + marker_val = dict_data.pop("marker", style_fields.get("marker")) + zorder_val = dict_data.pop("zorder", style_fields.get("zorder")) + alpha_val = dict_data.pop("alpha", style_fields.get("alpha")) + edgecolors_val = dict_data.pop( + "edgecolors", color_val + ) # default to agent's color if not provided + linewidths_val = dict_data.pop("linewidths", style_fields.get("linewidths")) + + aps = AgentPortrayalStyle( + x=agent_x, + y=agent_y, + size=size_val, + color=color_val, + marker=marker_val, + zorder=zorder_val, + alpha=alpha_val, + edgecolors=edgecolors_val, + linewidths=linewidths_val, + ) + + # Report list of unused data + if dict_data: + ignored_keys = list(dict_data.keys()) + warnings.warn( + f"The following keys from the returned dict were ignored: {', '.join(ignored_keys)}", + UserWarning, + stacklevel=2, + ) + else: + aps = portray_input + # default to agent's color if not provided + if aps.edgecolors is None: + aps.edgecolors = aps.color + # get position if not specified + if aps.x is None and aps.y is None: + aps.x, aps.y = get_agent_pos(agent, space) + + # Collect common data from the AgentPortrayalStyle instance + arguments["loc"].append((aps.x, aps.y)) + + # Determine final size for collection + size_to_collect = aps.size + if size_to_collect is None: + size_to_collect = default_size + if size_to_collect is None: + size_to_collect = class_default_size + + arguments["s"].append(size_to_collect) + arguments["c"].append(aps.color) + arguments["marker"].append(aps.marker) + arguments["zorder"].append(aps.zorder) + arguments["alpha"].append(aps.alpha) + arguments["edgecolors"].append(aps.edgecolors) + arguments["linewidths"].append(aps.linewidths) data = { k: (np.asarray(v, dtype=object) if k == "marker" else np.asarray(v)) @@ -115,7 +175,7 @@ def collect_agent_data( def draw_space( space, agent_portrayal: Callable, - propertylayer_portrayal: dict | None = None, + propertylayer_portrayal: Callable | None = None, ax: Axes | None = None, **space_drawing_kwargs, ): @@ -123,15 +183,15 @@ def draw_space( Args: space: the space of the mesa model - agent_portrayal: A callable that returns a dict specifying how to show the agent - propertylayer_portrayal: a dict specifying how to show propertylayer(s) + agent_portrayal: A callable that returns a AgnetPortrayalStyle specifying how to show the agent + propertylayer_portrayal: A callable that returns a PropertyLayerStyle specifying how to show the property layer ax: the axes upon which to draw the plot space_drawing_kwargs: any additional keyword arguments to be passed on to the underlying function for drawing the space. Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ @@ -203,21 +263,52 @@ def _get_hex_vertices( def draw_property_layers( - space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes + space, propertylayer_portrayal: dict[str, dict[str, Any]] | Callable, ax: Axes ): """Draw PropertyLayers on the given axes. Args: space (mesa.space._Grid): The space containing the PropertyLayers. - propertylayer_portrayal (dict): the key is the name of the layer, the value is a dict with - fields specifying how the layer is to be portrayed + propertylayer_portrayal (Callable): A function that accepts a property layer object + and returns either a `PropertyLayerStyle` object defining its visualization, + or `None` to skip drawing this particular layer. ax (matplotlib.axes.Axes): The axes to draw on. - Notes: - valid fields in in the inner dict of propertylayer_portrayal are "alpha", "vmin", "vmax", "color" or "colormap", and "colorbar" - so you can do `{"some_layer":{"colormap":'viridis', 'alpha':.25, "colorbar":False}}` - """ + # Importing here to avoid circular import issues + from mesa.visualization.components import PropertyLayerStyle + + def _propertylayer_portryal_dict_to_callable( + propertylayer_portrayal: dict[str, dict[str, Any]], + ): + """Helper function to convert a propertylayer_portrayal dict to a callable that return a PropertyLayerStyle.""" + + def style_callable(layer_object: Any): + layer_name = layer_object.name + params = propertylayer_portrayal.get(layer_name) + + warnings.warn( + "The propertylayer_portrayal dict is deprecated. Use a callable that returns PropertyLayerStyle instead.", + DeprecationWarning, + stacklevel=2, + ) + + if params is None: + return None # Layer not specified in the dict, so skip. + + return PropertyLayerStyle( + color=params.get("color"), + colormap=params.get("colormap"), + alpha=params.get( + "alpha", PropertyLayerStyle.alpha + ), # Use defaults defined in the dataclass itself + vmin=params.get("vmin"), + vmax=params.get("vmax"), + colorbar=params.get("colorbar", PropertyLayerStyle.colorbar), + ) + + return style_callable + try: # old style spaces property_layers = space.properties @@ -225,12 +316,24 @@ def draw_property_layers( # new style spaces property_layers = space._mesa_property_layers - for layer_name, portrayal in propertylayer_portrayal.items(): + callable_portrayal: Callable[[Any], PropertyLayerStyle | None] + if isinstance(propertylayer_portrayal, dict): + callable_portrayal = _propertylayer_portryal_dict_to_callable( + propertylayer_portrayal + ) + else: + callable_portrayal = propertylayer_portrayal + + for layer_name in property_layers: + if layer_name == "empty": + # Skipping empty layer, automatically generated + continue + layer = property_layers.get(layer_name, None) - if not isinstance( - layer, - PropertyLayer | mesa.discrete_space.property_layer.PropertyLayer, - ): + portrayal = callable_portrayal(layer) + + if portrayal is None: + # Not visualizing layers that do not have a defined visual encoding. continue data = layer.data.astype(float) if layer.data.dtype == bool else layer.data @@ -242,20 +345,19 @@ def draw_property_layers( stacklevel=2, ) - # Get portrayal properties, or use defaults - alpha = portrayal.get("alpha", 1) - vmin = portrayal.get("vmin", np.min(data)) - vmax = portrayal.get("vmax", np.max(data)) - colorbar = portrayal.get("colorbar", True) + color = portrayal.color + colormap = portrayal.colormap + alpha = portrayal.alpha + vmin = portrayal.vmin if portrayal.vmin else np.min(data) + vmax = portrayal.vmax if portrayal.vmax else np.max(data) - # Prepare colormap - if "color" in portrayal: - rgba_color = to_rgba(portrayal["color"]) + if color: + rgba_color = to_rgba(color) cmap = LinearSegmentedColormap.from_list( layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)] ) - elif "colormap" in portrayal: - cmap = portrayal.get("colormap", "viridis") + elif colormap: + cmap = colormap if isinstance(cmap, list): cmap = LinearSegmentedColormap.from_list(layer_name, cmap) elif isinstance(cmap, str): @@ -266,7 +368,7 @@ def draw_property_layers( ) if isinstance(space, OrthogonalGrid): - if "color" in portrayal: + if color: data = data.T normalized_data = (data - vmin) / (vmax - vmin) rgba_data = np.full((*data.shape, 4), rgba_color) @@ -282,36 +384,26 @@ def draw_property_layers( vmax=vmax, origin="lower", ) - elif isinstance(space, HexGrid): width, height = data.shape - - # Generate hexagon mesh hexagons = _get_hexmesh(width, height) - - # Normalize colors norm = Normalize(vmin=vmin, vmax=vmax) - colors = data.ravel() # flatten data to 1D array + colors = data.ravel() - if "color" in portrayal: + if color: normalized_colors = np.clip(norm(colors), 0, 1) rgba_colors = np.full((len(colors), 4), rgba_color) rgba_colors[:, 3] = normalized_colors * alpha else: rgba_colors = cmap(norm(colors)) rgba_colors[..., 3] *= alpha - - # Draw hexagons collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1) ax.add_collection(collection) - else: raise NotImplementedError( f"PropertyLayer visualization not implemented for {type(space)}." ) - - # Add colorbar if requested - if colorbar: + if portrayal.colorbar: norm = Normalize(vmin=vmin, vmax=vmax) sm = ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) @@ -329,7 +421,7 @@ def draw_orthogonal_grid( Args: space: the space to visualize - agent_portrayal: a callable that is called with the agent and returns a dict + agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid kwargs: additional keyword arguments passed to ax.scatter @@ -337,8 +429,8 @@ def draw_orthogonal_grid( Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", + "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: @@ -346,7 +438,7 @@ def draw_orthogonal_grid( # gather agent data s_default = (180 / max(space.width, space.height)) ** 2 - arguments = collect_agent_data(space, agent_portrayal, size=s_default) + arguments = collect_agent_data(space, agent_portrayal, default_size=s_default) # plot the agents _scatter(ax, arguments, **kwargs) @@ -376,17 +468,22 @@ def draw_hex_grid( Args: space: the space to visualize - agent_portrayal: a callable that is called with the agent and returns a dict + agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid kwargs: additional keyword arguments passed to ax.scatter + Returns: + Returns the Axes object with the plot drawn onto it. + + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", + "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: fig, ax = plt.subplots() # gather data s_default = (180 / max(space.width, space.height)) ** 2 - arguments = collect_agent_data(space, agent_portrayal, size=s_default) + arguments = collect_agent_data(space, agent_portrayal, default_size=s_default) # Parameters for hexagon grid size = 1.0 @@ -452,7 +549,7 @@ def draw_network( Args: space: the space to visualize - agent_portrayal: a callable that is called with the agent and returns a dict + agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid layout_alg: a networkx layout algorithm or other callable with the same behavior @@ -462,8 +559,8 @@ def draw_network( Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", + "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: @@ -485,12 +582,19 @@ def draw_network( # gather agent data s_default = (180 / max(width, height)) ** 2 - arguments = collect_agent_data(space, agent_portrayal, size=s_default) + arguments = collect_agent_data(space, agent_portrayal, default_size=s_default) # this assumes that nodes are identified by an integer # which is true for default nx graphs but might user changeable pos = np.asarray(list(pos.values())) - arguments["loc"] = pos[arguments["loc"]] + loc = arguments["loc"] + + # For network only one of x and y contains the correct coordinates + x = loc[:, 0] + if x is None: + x = loc[:, 1] + + arguments["loc"] = pos[x] # plot the agents _scatter(ax, arguments, **kwargs) @@ -517,15 +621,15 @@ def draw_continuous_space( Args: space: the space to visualize - agent_portrayal: a callable that is called with the agent and returns a dict + agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots kwargs: additional keyword arguments passed to ax.scatter Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", + "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: @@ -539,7 +643,7 @@ def draw_continuous_space( # gather agent data s_default = (180 / max(width, height)) ** 2 - arguments = collect_agent_data(space, agent_portrayal, size=s_default) + arguments = collect_agent_data(space, agent_portrayal, default_size=s_default) # plot the agents _scatter(ax, arguments, **kwargs) @@ -568,7 +672,7 @@ def draw_voronoi_grid( Args: space: the space to visualize - agent_portrayal: a callable that is called with the agent and returns a dict + agent_portrayal: a callable that is called with the agent and returns a AgentPortrayalStyle ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots draw_grid: whether to draw the grid or not kwargs: additional keyword arguments passed to ax.scatter @@ -576,8 +680,8 @@ def draw_voronoi_grid( Returns: Returns the Axes object with the plot drawn onto it. - ``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color", - "size", "marker", and "zorder". Other field are ignored and will result in a user warning. + ``agent_portrayal`` is called with an agent and should return a AgentPortrayalStyle. Valid fields in this object are "color", + "size", "marker", "zorder", alpha, linewidths, and edgecolors. Other field are ignored and will result in a user warning. """ if ax is None: @@ -596,7 +700,7 @@ def draw_voronoi_grid( y_padding = height / 20 s_default = (180 / max(width, height)) ** 2 - arguments = collect_agent_data(space, agent_portrayal, size=s_default) + arguments = collect_agent_data(space, agent_portrayal, default_size=s_default) ax.set_xlim(x_min - x_padding, x_max + x_padding) ax.set_ylim(y_min - y_padding, y_max + y_padding) diff --git a/tests/test_portrayal_components.py b/tests/test_portrayal_components.py new file mode 100644 index 00000000000..5c8faabb57a --- /dev/null +++ b/tests/test_portrayal_components.py @@ -0,0 +1,139 @@ +"""Tests for the portrayal components in Mesa visualization.""" + +from dataclasses import is_dataclass + +import pytest + +from mesa.visualization.components import AgentPortrayalStyle, PropertyLayerStyle + + +def test_agent_portrayal_style_is_dataclass(): + """Test if AgentPortrayalStyle is a dataclass.""" + assert is_dataclass(AgentPortrayalStyle) + + +def test_agent_portrayal_style_defaults(): + """Test default values of AgentPortrayalStyle.""" + style = AgentPortrayalStyle() + assert style.x is None + assert style.y is None + assert style.color == "tab:blue" + assert style.marker == "o" + assert style.size == 50 + assert style.zorder == 1 + assert style.alpha == 1.0 + assert style.edgecolors is None + assert style.linewidths == 1.0 + + +def test_agent_portrayal_style_custom_initialization(): + """Test custom initialization of AgentPortrayalStyle.""" + style = AgentPortrayalStyle( + x=10, + y=20, + color="red", + marker="^", + size=100, + zorder=2, + alpha=0.5, + edgecolors="black", + linewidths=2.0, + ) + assert style.x == 10 + assert style.y == 20 + assert style.color == "red" + assert style.marker == "^" + assert style.size == 100 + assert style.zorder == 2 + assert style.alpha == 0.5 + assert style.edgecolors == "black" + assert style.linewidths == 2.0 + + +def test_agent_portrayal_style_update_attributes(): + """Test updating attributes of AgentPortrayalStyle.""" + style = AgentPortrayalStyle() + style.update(("marker", "s"), ("size", 75), ("alpha", 0.7)) + assert style.marker == "s" + assert style.size == 75 + assert style.alpha == 0.7 + + +def test_agent_portrayal_style_update_non_existent_attribute(): + """Test updating a non-existent attribute raises AttributeError.""" + style = AgentPortrayalStyle() + with pytest.raises( + AttributeError, match="'AgentPortrayalStyle' object has no attribute 'shape'" + ): + style.update(("shape", "triangle")) + + +def test_agent_portrayal_style_update_with_no_arguments(): + """Test updating AgentPortrayalStyle with no arguments does not change the style.""" + original_style = AgentPortrayalStyle(color="blue") + updated_style = AgentPortrayalStyle(color="blue") + updated_style.update() + assert updated_style.color == original_style.color # Ensure no change + + +def test_property_layer_style_is_dataclass(): + """Test if PropertyLayerStyle is a dataclass.""" + assert is_dataclass(PropertyLayerStyle) + + +def test_property_layer_style_default_values_with_colormap(): + """Test default values of PropertyLayerStyle with colormap.""" + style = PropertyLayerStyle(colormap="viridis") + assert style.colormap == "viridis" + assert style.color is None + assert style.alpha == 0.8 + assert style.colorbar is True + assert style.vmin is None + assert style.vmax is None + + +def test_property_layer_style_default_values_with_color(): + """Test default values of PropertyLayerStyle with color.""" + style = PropertyLayerStyle(color="red") + assert style.colormap is None + assert style.color == "red" + assert style.alpha == 0.8 + assert style.colorbar is True + assert style.vmin is None + assert style.vmax is None + + +def test_property_layer_style_custom_initialization_with_colormap(): + """Test custom initialization of PropertyLayerStyle with colormap.""" + style = PropertyLayerStyle( + colormap="plasma", alpha=0.5, colorbar=False, vmin=0, vmax=1 + ) + assert style.colormap == "plasma" + assert style.color is None + assert style.alpha == 0.5 + assert style.colorbar is False + assert style.vmin == 0 + assert style.vmax == 1 + + +def test_property_layer_style_custom_initialization_with_color(): + """Test custom initialization of PropertyLayerStyle with color.""" + style = PropertyLayerStyle(color="blue", alpha=0.9, colorbar=False) + assert style.colormap is None + assert style.color == "blue" + assert style.alpha == 0.9 + assert style.colorbar is False + + +def test_property_layer_style_post_init_both_color_and_colormap_error(): + """Test error when both color and colormap are specified.""" + with pytest.raises( + ValueError, match="Specify either 'color' or 'colormap', not both." + ): + PropertyLayerStyle(colormap="viridis", color="red") + + +def test_property_layer_style_post_init_neither_color_nor_colormap_error(): + """Test error when neither color nor colormap is specified.""" + with pytest.raises(ValueError, match="Specify one of 'color' or 'colormap'"): + PropertyLayerStyle()