Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 53 additions & 89 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import contextlib
import itertools
import warnings
from collections.abc import Callable, Iterator
from functools import lru_cache
from collections.abc import Callable
from itertools import pairwise
from typing import Any

Expand All @@ -19,7 +18,7 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.patches import Polygon

Expand Down Expand Up @@ -160,37 +159,6 @@ def draw_space(
return ax


@lru_cache(maxsize=1024, typed=True)
def _get_hexmesh(
width: int, height: int, size: float = 1.0
) -> Iterator[list[tuple[float, float]]]:
"""Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon."""

# Helper function for getting the vertices of a hexagon given the center and size
def _get_hex_vertices(
center_x: float, center_y: float, size: float = 1.0
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

for row, col in itertools.product(range(height), range(width)):
# Calculate center position with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing
yield _get_hex_vertices(x, y, size)


def draw_property_layers(
space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
):
Expand Down Expand Up @@ -237,74 +205,46 @@ def draw_property_layers(
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Prepare colormap
# Draw the layer
if "color" in portrayal:
data = data.T
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
cmap = LinearSegmentedColormap.from_list(
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
)
im = ax.imshow(
rgba_data,
origin="lower",
)
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
ax.figure.colorbar(sm, ax=ax, orientation="vertical")

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
elif isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
im = ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)
if colorbar:
plt.colorbar(im, ax=ax, label=layer_name)
else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)

if isinstance(space, OrthogonalGrid):
if "color" in portrayal:
data = data.T
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
ax.imshow(rgba_data, origin="lower")
else:
ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)

elif isinstance(space, HexGrid):
width, height = data.shape

# Generate hexagon mesh
hexagons = _get_hexmesh(width, height)

# Normalize colors
norm = Normalize(vmin=vmin, vmax=vmax)
colors = data.ravel() # flatten data to 1D array

if "color" in portrayal:
normalized_colors = np.clip(norm(colors), 0, 1)
rgba_colors = np.full((len(colors), 4), rgba_color)
rgba_colors[:, 3] = normalized_colors * alpha
else:
rgba_colors = cmap(norm(colors))

# Draw hexagons
collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
ax.add_collection(collection)

else:
raise NotImplementedError(
f"PropertyLayer visualization not implemented for {type(space)}."
)

# Add colorbar if requested
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
plt.colorbar(sm, ax=ax, label=layer_name)


def draw_orthogonal_grid(
space: OrthogonalGrid,
Expand Down Expand Up @@ -409,15 +349,39 @@ def draw_hex_grid(
def setup_hexmesh(width, height):
"""Helper function for creating the hexmesh with unique edges."""
edges = set()
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

def get_hex_vertices(
center_x: float, center_y: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

# Generate edges for each hexagon
for vertices in _get_hexmesh(width, height):
for row, col in itertools.product(range(height), range(width)):
# Calculate center position for each hexagon with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing

vertices = get_hex_vertices(x, y)

# Edge logic, connecting each vertex to the next
for v1, v2 in pairwise([*vertices, vertices[0]]):
# Sort vertices to ensure consistent edge representation and avoid duplicates.
edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
edges.add(edge)

# Return LineCollection for hexmesh
return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1)

if draw_grid:
Expand Down
Loading