Skip to content

Commit 47eff74

Browse files
Add Altair plotting functionality (#2810)
* Add altair plotting functionality * empty commit * make renderer priority * remove layer name from chart title * update the drawing of property layers in altair * fix tests --------- Co-authored-by: Tom Pike <[email protected]>
1 parent f9e6dbb commit 47eff74

File tree

6 files changed

+138
-200
lines changed

6 files changed

+138
-200
lines changed

mesa/visualization/backends/altair_backend.py

Lines changed: 39 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,13 @@ def draw_agents(
299299
"x:Q",
300300
title=xlabel,
301301
scale=alt.Scale(type="linear", domain=[xmin, xmax]),
302+
axis=None,
302303
),
303304
y=alt.Y(
304305
"y:Q",
305306
title=ylabel,
306307
scale=alt.Scale(type="linear", domain=[ymin, ymax]),
308+
axis=None,
307309
),
308310
size=alt.Size("size:Q", legend=None, scale=alt.Scale(domain=[0, 50])),
309311
shape=alt.Shape(
@@ -352,8 +354,7 @@ def draw_propertylayer(
352354
Returns:
353355
alt.Chart: A tuple containing the base chart and the color bar chart.
354356
"""
355-
base = None
356-
bar_chart_viz = None
357+
main_charts = []
357358

358359
for layer_name in property_layers:
359360
if layer_name == "empty":
@@ -384,7 +385,6 @@ def draw_propertylayer(
384385
vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
385386
vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)
386387

387-
# Prepare data for Altair
388388
df = pd.DataFrame(
389389
{
390390
"x": np.repeat(np.arange(data.shape[0]), data.shape[1]),
@@ -393,183 +393,48 @@ def draw_propertylayer(
393393
}
394394
)
395395

396-
current_chart = None
397396
if color:
398-
# Create a function to map values to RGBA colors with proper opacity scaling
399-
def apply_rgba(
400-
val, v_min=vmin, v_max=vmax, a=alpha, p_color=portrayal.color
401-
):
402-
# Normalize value to range [0,1] and clamp
403-
normalized = max(
404-
0,
405-
min(
406-
((val - v_min) / (v_max - v_min))
407-
if (v_max - v_min) != 0
408-
else 0.5,
409-
1,
410-
),
411-
)
412-
413-
# Scale opacity by alpha parameter
414-
opacity = normalized * a
415-
416-
# Convert color to RGB components
417-
rgb_color_val = to_rgb(p_color)
418-
r = int(rgb_color_val[0] * 255)
419-
g = int(rgb_color_val[1] * 255)
420-
b = int(rgb_color_val[2] * 255)
421-
return f"rgba({r}, {g}, {b}, {opacity:.2f})"
422-
423-
# Apply color mapping to each value in the dataset
424-
df["color_str"] = df["value"].apply(apply_rgba)
425-
426-
# Create chart for the property layer
427-
current_chart = (
428-
alt.Chart(df)
429-
.mark_rect()
430-
.encode(
431-
x=alt.X("x:O", axis=None),
432-
y=alt.Y("y:O", axis=None),
433-
fill=alt.Fill("color_str:N", scale=None),
434-
)
435-
.properties(
436-
width=chart_width, height=chart_height, title=layer_name
437-
)
397+
# For a single color gradient, we define the range from transparent to solid.
398+
rgb = to_rgb(color)
399+
r, g, b = (int(c * 255) for c in rgb)
400+
401+
min_color = f"rgba({r},{g},{b},0)"
402+
max_color = f"rgba({r},{g},{b},{alpha})"
403+
opacity = 1
404+
color_scale = alt.Scale(
405+
range=[min_color, max_color], domain=[vmin, vmax]
438406
)
439-
base = (
440-
alt.layer(current_chart, base)
441-
if base is not None
442-
else current_chart
443-
)
444-
445-
# Add colorbar if specified in portrayal
446-
if portrayal.colorbar:
447-
# Extract RGB components from base color
448-
rgb_color_val = to_rgb(portrayal.color)
449-
r_int = int(rgb_color_val[0] * 255)
450-
g_int = int(rgb_color_val[1] * 255)
451-
b_int = int(rgb_color_val[2] * 255)
452-
453-
# Define gradient endpoints
454-
min_color_str = f"rgba({r_int},{g_int},{b_int},0)"
455-
max_color_str = f"rgba({r_int},{g_int},{b_int},{alpha:.2f})"
456-
457-
# Define colorbar dimensions
458-
colorbar_height = 20
459-
colorbar_width = chart_width
460-
461-
# Create dataframe for gradient visualization
462-
df_gradient = pd.DataFrame({"x_grad": [0, 1], "y_grad": [0, 1]})
463-
464-
# Create evenly distributed tick values
465-
axis_values = np.linspace(vmin, vmax, 11)
466-
tick_positions = np.linspace(10, colorbar_width - 10, 11)
467-
468-
# Prepare data for axis and labels
469-
axis_data = pd.DataFrame(
470-
{"value_axis": axis_values, "x_axis": tick_positions}
471-
)
472-
473-
# Create colorbar with linear gradient
474-
colorbar_chart_obj = (
475-
alt.Chart(df_gradient)
476-
.mark_rect(
477-
x=20,
478-
y=0,
479-
width=colorbar_width - 20,
480-
height=colorbar_height,
481-
color=alt.Gradient(
482-
gradient="linear",
483-
stops=[
484-
alt.GradientStop(color=min_color_str, offset=0),
485-
alt.GradientStop(color=max_color_str, offset=1),
486-
],
487-
x1=0,
488-
x2=1, # Horizontal gradient
489-
y1=0,
490-
y2=0, # Keep y constant
491-
),
492-
)
493-
.encode(
494-
x=alt.value(chart_width / 2), y=alt.value(8)
495-
) # Center colorbar
496-
.properties(width=colorbar_width, height=colorbar_height)
497-
)
498-
# Add tick marks to colorbar
499-
axis_chart = (
500-
alt.Chart(axis_data)
501-
.mark_tick(thickness=2, size=10)
502-
.encode(
503-
x=alt.X("x_axis:Q", axis=None),
504-
y=alt.value(colorbar_height - 2),
505-
)
506-
)
507-
# Add value labels below tick marks
508-
text_labels = (
509-
alt.Chart(axis_data)
510-
.mark_text(baseline="top", fontSize=10, dy=0)
511-
.encode(
512-
x=alt.X("x_axis:Q"),
513-
text=alt.Text("value_axis:Q", format=".1f"),
514-
y=alt.value(colorbar_height + 10),
515-
)
516-
)
517-
# Add title to colorbar
518-
title_chart = (
519-
alt.Chart(pd.DataFrame([{"text_title": layer_name}]))
520-
.mark_text(
521-
fontSize=12,
522-
fontWeight="bold",
523-
baseline="bottom",
524-
align="center",
525-
)
526-
.encode(
527-
text="text_title:N",
528-
x=alt.value(colorbar_width / 2),
529-
y=alt.value(colorbar_height + 40),
530-
)
531-
)
532-
# Combine all colorbar components
533-
combined_colorbar = alt.layer(
534-
colorbar_chart_obj, axis_chart, text_labels, title_chart
535-
).properties(width=colorbar_width, height=colorbar_height + 50)
536-
537-
bar_chart_viz = (
538-
alt.vconcat(bar_chart_viz, combined_colorbar).resolve_scale(
539-
color="independent"
540-
)
541-
if bar_chart_viz is not None
542-
else combined_colorbar
543-
)
544407

545408
elif colormap:
546409
cmap = colormap
547-
cmap_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])
548-
549-
current_chart = (
550-
alt.Chart(df)
551-
.mark_rect(opacity=alpha)
552-
.encode(
553-
x=alt.X("x:O", axis=None),
554-
y=alt.Y("y:O", axis=None),
555-
color=alt.Color(
556-
"value:Q",
557-
scale=cmap_scale,
558-
title=layer_name,
559-
legend=alt.Legend(title=layer_name)
560-
if portrayal.colorbar
561-
else None,
562-
),
563-
)
564-
.properties(width=chart_width, height=chart_height)
565-
)
566-
base = (
567-
alt.layer(current_chart, base)
568-
if base is not None
569-
else current_chart
570-
)
410+
color_scale = alt.Scale(scheme=cmap, domain=[vmin, vmax])
411+
opacity = alpha
412+
571413
else:
572414
raise ValueError(
573415
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
574416
)
575-
return (base, bar_chart_viz)
417+
418+
current_chart = (
419+
alt.Chart(df)
420+
.mark_rect(opacity=opacity)
421+
.encode(
422+
x=alt.X("x:O", axis=None),
423+
y=alt.Y("y:O", axis=None),
424+
color=alt.Color(
425+
"value:Q",
426+
scale=color_scale,
427+
title=layer_name,
428+
legend=alt.Legend(title=layer_name, orient="bottom")
429+
if portrayal.colorbar
430+
else None,
431+
),
432+
)
433+
.properties(width=chart_width, height=chart_height)
434+
)
435+
436+
if current_chart is not None:
437+
main_charts.append(current_chart)
438+
439+
base = alt.layer(*main_charts).resolve_scale(color="independent")
440+
return base

mesa/visualization/components/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from collections.abc import Callable
66

7-
from .altair_components import SpaceAltair, make_altair_space
7+
from .altair_components import (
8+
SpaceAltair,
9+
make_altair_plot_component,
10+
make_altair_space,
11+
)
812
from .matplotlib_components import (
913
SpaceMatplotlib,
1014
make_mpl_plot_component,
@@ -80,16 +84,13 @@ def make_plot_component(
8084
backend: the backend to use {"matplotlib", "altair"}
8185
plot_drawing_kwargs: additional keyword arguments to pass onto the backend specific function for making a plotting component
8286
83-
Notes:
84-
altair plotting backend is not yet implemented and planned for mesa 3.1.
85-
8687
Returns:
8788
function: A function that creates a plot component
8889
"""
8990
if backend == "matplotlib":
9091
return make_mpl_plot_component(measure, post_process, **plot_drawing_kwargs)
9192
elif backend == "altair":
92-
raise NotImplementedError("altair line plots are not yet implemented")
93+
return make_altair_plot_component(measure, post_process, **plot_drawing_kwargs)
9394
else:
9495
raise ValueError(
9596
f"unknown backend {backend}, must be one of matplotlib, altair"

mesa/visualization/components/altair_components.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Altair based solara components for visualization mesa spaces."""
22

33
import warnings
4+
from collections.abc import Callable
45

56
import altair as alt
67
import numpy as np
@@ -448,3 +449,86 @@ def apply_rgba(val, vmin=vmin, vmax=vmax, alpha=alpha, portrayal=portrayal):
448449
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
449450
)
450451
return base, bar_chart
452+
453+
454+
def make_altair_plot_component(
455+
measure: str | dict[str, str] | list[str] | tuple[str],
456+
post_process: Callable | None = None,
457+
grid=False,
458+
):
459+
"""Create a plotting function for a specified measure.
460+
461+
Args:
462+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
463+
post_process: a user-specified callable to do post-processing called with the Axes instance.
464+
grid: Bool to draw grid or not.
465+
466+
Returns:
467+
function: A function that creates a PlotAltair component.
468+
"""
469+
470+
def MakePlotAltair(model):
471+
return PlotAltair(model, measure, post_process=post_process, grid=grid)
472+
473+
return MakePlotAltair
474+
475+
476+
@solara.component
477+
def PlotAltair(model, measure, post_process: Callable | None = None, grid=False):
478+
"""Create an Altair-based plot for a measure or measures.
479+
480+
Args:
481+
model (mesa.Model): The model instance.
482+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
483+
If a dict is given, keys are measure names and values are colors.
484+
post_process: A user-specified callable for post-processing, called
485+
with the Altair Chart instance.
486+
grid: Bool to draw grid or not.
487+
488+
Returns:
489+
solara.FigureAltair: A component for rendering the plot.
490+
"""
491+
update_counter.get()
492+
df = model.datacollector.get_model_vars_dataframe().reset_index()
493+
df = df.rename(columns={"index": "Step"})
494+
495+
y_title = "Value"
496+
if isinstance(measure, str):
497+
measures_to_plot = [measure]
498+
y_title = measure
499+
elif isinstance(measure, list | tuple):
500+
measures_to_plot = list(measure)
501+
elif isinstance(measure, dict):
502+
measures_to_plot = list(measure.keys())
503+
504+
df_long = df.melt(
505+
id_vars=["Step"],
506+
value_vars=measures_to_plot,
507+
var_name="Measure",
508+
value_name="Value",
509+
)
510+
511+
chart = (
512+
alt.Chart(df_long)
513+
.mark_line()
514+
.encode(
515+
x=alt.X("Step:Q", axis=alt.Axis(tickMinStep=1, title="Step", grid=grid)),
516+
y=alt.Y("Value:Q", axis=alt.Axis(title=y_title, grid=grid)),
517+
tooltip=["Step", "Measure", "Value"],
518+
)
519+
.properties(width=450, height=350)
520+
.interactive()
521+
)
522+
523+
if len(measures_to_plot) > 0:
524+
color_args = {}
525+
if isinstance(measure, dict):
526+
color_args["scale"] = alt.Scale(
527+
domain=list(measure.keys()), range=list(measure.values())
528+
)
529+
chart = chart.encode(color=alt.Color("Measure:N", **color_args))
530+
531+
if post_process is not None:
532+
chart = post_process(chart)
533+
534+
return solara.FigureAltair(chart)

0 commit comments

Comments
 (0)