From a12e74f95a34a37254caaba3d6a764d16b8994f4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 7 May 2024 15:29:33 +0200 Subject: [PATCH 1/8] pull out node kwargs --- pymc/model_graph.py | 205 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 167 insertions(+), 38 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 65d518c8a..139d65e55 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -14,8 +14,10 @@ import warnings from collections import defaultdict -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence +from enum import Enum from os import path +from typing import Any from pytensor import function from pytensor.graph import Apply @@ -41,6 +43,118 @@ def fast_eval(var): return function([], var, mode="FAST_COMPILE")() +class NodeType(str, Enum): + """Enum for the types of nodes in the graph.""" + + POTENTIAL = "Potential" + BASIC_RV = "Basic Random Variable" + OBSERVED_RV = "Observed Random Variable" + DETERMINISTIC = "Deterministic" + DATA = "Data" + + +GraphvizNodeKwargs = dict[str, Any] +NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] + + +def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for the node in the graph.""" + return { + "shape": "octagon", + "style": "filled", + "label": f"{var.name}\n~\nPotential", + } + + +def random_variable_symbol(var: TensorVariable) -> str: + """Get the name of the random variable.""" + symbol = var.owner.op.__class__.__name__ + + if symbol.endswith("RV"): + symbol = symbol[:-2] + + return symbol + + +def default_basic_rv(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for the node in the graph.""" + symbol = random_variable_symbol(var) + + return { + "shape": "ellipse", + "style": None, + "label": f"{var.name}\n~\n{symbol}", + } + + +def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for the node in the graph.""" + symbol = random_variable_symbol(var) + + return { + "shape": "ellipse", + "style": "filled", + "label": f"{var.name}\n~\n{symbol}", + } + + +def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for the node in the graph.""" + return { + "shape": "box", + "style": None, + "label": f"{var.name}\n~\nDeterministic", + } + + +def default_data(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for the node in the graph.""" + return { + "shape": "box", + "style": "rounded, filled", + "label": f"{var.name}\n~\nData", + } + + +def get_node_type(var_name: VarName, model) -> NodeType: + v = model[var_name] + + if v in model.potentials: + return NodeType.POTENTIAL + elif v in model.basic_RVs and v in model.observed_RVs: + return NodeType.OBSERVED_RV + elif v in model.basic_RVs: + return NodeType.BASIC_RV + elif v in model.deterministics: + return NodeType.DETERMINISTIC + else: + return NodeType.DATA + + +NodeTypeFormatterMapping = dict[NodeType, NodeFormatter] + +DEFAULT_NODE_FORMATTERS: NodeTypeFormatterMapping = { + NodeType.POTENTIAL: default_potential, + NodeType.BASIC_RV: default_basic_rv, + NodeType.OBSERVED_RV: default_observed_rv, + NodeType.DETERMINISTIC: default_deterministic, + NodeType.DATA: default_data, +} + + +def update_node_formatters(node_formatters: NodeTypeFormatterMapping): + node_formatters = {**DEFAULT_NODE_FORMATTERS, **node_formatters} + + unknown_keys = set(node_formatters.keys()) - set(NodeType) + if unknown_keys: + raise ValueError( + f"Node formatters must be of type NodeType. Found: {list(unknown_keys)}." + f" Please use one of {[node_type.value for node_type in NodeType]}." + ) + + return node_formatters + + class ModelGraph: def __init__(self, model): self.model = model @@ -148,42 +262,23 @@ def make_compute_graph( return input_map - def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"): + def _make_node( + self, + var_name, + graph, + *, + node_formatters: NodeTypeFormatterMapping, + nx=False, + cluster=False, + formatting: str = "plain", + ): """Attaches the given variable to a graphviz or networkx Digraph""" v = self.model[var_name] - shape = None - style = None - label = str(v) - - if v in self.model.potentials: - shape = "octagon" - style = "filled" - label = f"{var_name}\n~\nPotential" - elif v in self.model.basic_RVs: - shape = "ellipse" - if v in self.model.observed_RVs: - style = "filled" - else: - style = None - symbol = v.owner.op.__class__.__name__ - if symbol.endswith("RV"): - symbol = symbol[:-2] - label = f"{var_name}\n~\n{symbol}" - elif v in self.model.deterministics: - shape = "box" - style = None - label = f"{var_name}\n~\nDeterministic" - else: - shape = "box" - style = "rounded, filled" - label = f"{var_name}\n~\nData" + node_type = get_node_type(var_name, self.model) + node_formatter = node_formatters[node_type] - kwargs = { - "shape": shape, - "style": style, - "label": label, - } + kwargs = node_formatter(v) if cluster: kwargs["cluster"] = cluster @@ -240,6 +335,7 @@ def make_graph( save=None, figsize=None, dpi=300, + node_formatters: NodeTypeFormatterMapping | None = None, ): """Make graphviz Digraph of PyMC model @@ -255,18 +351,26 @@ def make_graph( "The easiest way to install all of this is by running\n\n" "\tconda install -c conda-forge python-graphviz" ) + + node_formatters = node_formatters or {} + node_formatters = update_node_formatters(node_formatters) + graph = graphviz.Digraph(self.model.name) for plate_label, all_var_names in self.get_plates(var_names).items(): if plate_label: # must be preceded by 'cluster' to get a box around it with graph.subgraph(name="cluster" + plate_label) as sub: for var_name in all_var_names: - self._make_node(var_name, sub, formatting=formatting) + self._make_node( + var_name, sub, formatting=formatting, node_formatters=node_formatters + ) # plate label goes bottom right sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") else: for var_name in all_var_names: - self._make_node(var_name, graph, formatting=formatting) + self._make_node( + var_name, graph, formatting=formatting, node_formatters=node_formatters + ) for child, parents in self.make_compute_graph(var_names=var_names).items(): # parents is a set of rv names that precede child rv nodes @@ -287,7 +391,12 @@ def make_graph( return graph - def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: str = "plain"): + def make_networkx( + self, + var_names: Iterable[VarName] | None = None, + formatting: str = "plain", + node_formatters: NodeTypeFormatterMapping | None = None, + ): """Make networkx Digraph of PyMC model Returns @@ -302,6 +411,10 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: "The easiest way to install all of this is by running\n\n" "\tconda install networkx" ) + + node_formatters = node_formatters or {} + node_formatters = update_node_formatters(node_formatters) + graphnetwork = networkx.DiGraph(name=self.model.name) for plate_label, all_var_names in self.get_plates(var_names).items(): if plate_label: @@ -314,6 +427,7 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: var_name, subgraphnetwork, nx=True, + node_formatters=node_formatters, cluster="cluster" + plate_label, formatting=formatting, ) @@ -332,7 +446,13 @@ def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: graphnetwork.graph["name"] = self.model.name else: for var_name in all_var_names: - self._make_node(var_name, graphnetwork, nx=True, formatting=formatting) + self._make_node( + var_name, + graphnetwork, + nx=True, + formatting=formatting, + node_formatters=node_formatters, + ) for child, parents in self.make_compute_graph(var_names=var_names).items(): # parents is a set of rv names that precede child rv nodes @@ -346,6 +466,7 @@ def model_to_networkx( *, var_names: Iterable[VarName] | None = None, formatting: str = "plain", + node_formatters: NodeTypeFormatterMapping | None = None, ): """Produce a networkx Digraph from a PyMC model. @@ -367,6 +488,8 @@ def model_to_networkx( Subset of variables to be plotted that identify a subgraph with respect to the entire model graph formatting : str, optional one of { "plain" } + node_formatters : dict, optional + A dictionary mapping node types to functions that return a dictionary of node attributes. Examples -------- @@ -403,7 +526,9 @@ def model_to_networkx( stacklevel=2, ) model = pm.modelcontext(model) - return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting) + return ModelGraph(model).make_networkx( + var_names=var_names, formatting=formatting, node_formatters=node_formatters + ) def model_to_graphviz( @@ -414,6 +539,7 @@ def model_to_graphviz( save: str | None = None, figsize: tuple[int, int] | None = None, dpi: int = 300, + node_formatters: NodeTypeFormatterMapping | None = None, ): """Produce a graphviz Digraph from a PyMC model. @@ -441,6 +567,8 @@ def model_to_graphviz( the size of the saved figure. dpi : int, optional Dots per inch. It only affects the resolution of the saved figure. The default is 300. + node_formatters : dict, optional + A dictionary mapping node types to functions that return a dictionary of node attributes. Examples -------- @@ -491,4 +619,5 @@ def model_to_graphviz( save=save, figsize=figsize, dpi=dpi, + node_formatters=node_formatters, ) From f0629c8732dfc5b5a96f532934c89d6ac4a34f61 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 7 May 2024 16:26:52 +0200 Subject: [PATCH 2/8] add type hint --- pymc/model_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 139d65e55..5d116212f 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -142,7 +142,7 @@ def get_node_type(var_name: VarName, model) -> NodeType: } -def update_node_formatters(node_formatters: NodeTypeFormatterMapping): +def update_node_formatters(node_formatters: NodeTypeFormatterMapping) -> NodeTypeFormatterMapping: node_formatters = {**DEFAULT_NODE_FORMATTERS, **node_formatters} unknown_keys = set(node_formatters.keys()) - set(NodeType) From e4faf6090e2aea9cddb817194d01018bb05e834e Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+wd60622@users.noreply.github.com> Date: Tue, 7 May 2024 22:12:25 +0200 Subject: [PATCH 3/8] switch basic to free random variable Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/model_graph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 5d116212f..8e5212f87 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -47,7 +47,7 @@ class NodeType(str, Enum): """Enum for the types of nodes in the graph.""" POTENTIAL = "Potential" - BASIC_RV = "Basic Random Variable" + FREE_RV = "Basic Random Variable" OBSERVED_RV = "Observed Random Variable" DETERMINISTIC = "Deterministic" DATA = "Data" @@ -121,10 +121,10 @@ def get_node_type(var_name: VarName, model) -> NodeType: if v in model.potentials: return NodeType.POTENTIAL - elif v in model.basic_RVs and v in model.observed_RVs: + elif v in model.observed_RVs: return NodeType.OBSERVED_RV - elif v in model.basic_RVs: - return NodeType.BASIC_RV + elif v in model.Free_RVs: + return NodeType.Free_RV elif v in model.deterministics: return NodeType.DETERMINISTIC else: From 279316c5c342a24a4733f05f7f888f4aa3395321 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 7 May 2024 23:01:15 +0200 Subject: [PATCH 4/8] add to docstrings --- pymc/model_graph.py | 56 ++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 8e5212f87..e7bf16f89 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -47,7 +47,7 @@ class NodeType(str, Enum): """Enum for the types of nodes in the graph.""" POTENTIAL = "Potential" - FREE_RV = "Basic Random Variable" + FREE_RV = "Free Random Variable" OBSERVED_RV = "Observed Random Variable" DETERMINISTIC = "Deterministic" DATA = "Data" @@ -58,7 +58,7 @@ class NodeType(str, Enum): def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: - """Default data for the node in the graph.""" + """Default data for potential in the graph.""" return { "shape": "octagon", "style": "filled", @@ -67,7 +67,7 @@ def default_potential(var: TensorVariable) -> GraphvizNodeKwargs: def random_variable_symbol(var: TensorVariable) -> str: - """Get the name of the random variable.""" + """Get the symbol of the random variable.""" symbol = var.owner.op.__class__.__name__ if symbol.endswith("RV"): @@ -76,8 +76,8 @@ def random_variable_symbol(var: TensorVariable) -> str: return symbol -def default_basic_rv(var: TensorVariable) -> GraphvizNodeKwargs: - """Default data for the node in the graph.""" +def default_free_rv(var: TensorVariable) -> GraphvizNodeKwargs: + """Default data for free RV in the graph.""" symbol = random_variable_symbol(var) return { @@ -88,7 +88,7 @@ def default_basic_rv(var: TensorVariable) -> GraphvizNodeKwargs: def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: - """Default data for the node in the graph.""" + """Default data for observed RV in the graph.""" symbol = random_variable_symbol(var) return { @@ -99,7 +99,7 @@ def default_observed_rv(var: TensorVariable) -> GraphvizNodeKwargs: def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: - """Default data for the node in the graph.""" + """Default data for the deterministic in the graph.""" return { "shape": "box", "style": None, @@ -108,7 +108,7 @@ def default_deterministic(var: TensorVariable) -> GraphvizNodeKwargs: def default_data(var: TensorVariable) -> GraphvizNodeKwargs: - """Default data for the node in the graph.""" + """Default data for the data in the graph.""" return { "shape": "box", "style": "rounded, filled", @@ -117,25 +117,26 @@ def default_data(var: TensorVariable) -> GraphvizNodeKwargs: def get_node_type(var_name: VarName, model) -> NodeType: + """Return the node type of the variable in the model.""" v = model[var_name] - if v in model.potentials: - return NodeType.POTENTIAL + if v in model.deterministics: + return NodeType.DETERMINISTIC + elif v in model.free_RVs: + return NodeType.FREE_RV elif v in model.observed_RVs: return NodeType.OBSERVED_RV - elif v in model.Free_RVs: - return NodeType.Free_RV - elif v in model.deterministics: - return NodeType.DETERMINISTIC - else: + elif v in model.data_vars: return NodeType.DATA + else: + return NodeType.POTENTIAL NodeTypeFormatterMapping = dict[NodeType, NodeFormatter] DEFAULT_NODE_FORMATTERS: NodeTypeFormatterMapping = { NodeType.POTENTIAL: default_potential, - NodeType.BASIC_RV: default_basic_rv, + NodeType.FREE_RV: default_free_rv, NodeType.OBSERVED_RV: default_observed_rv, NodeType.DETERMINISTIC: default_deterministic, NodeType.DATA: default_data, @@ -515,6 +516,17 @@ def model_to_networkx( obs = Normal("obs", theta, sigma=sigma, observed=y) model_to_networkx(schools) + + Add custom attributes to Free Random Variables and Observed Random Variables nodes. + + .. code-block:: python + + node_formatters = { + "Free Random Variable": lambda var: {"shape": "circle", "label": var.name}, + "Observed Random Variable": lambda var: {"shape": "square", "label": var.name}, + } + model_to_networkx(schools, node_formatters=node_formatters) + """ if "plain" not in formatting: raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.") @@ -569,6 +581,8 @@ def model_to_graphviz( Dots per inch. It only affects the resolution of the saved figure. The default is 300. node_formatters : dict, optional A dictionary mapping node types to functions that return a dictionary of node attributes. + Check out graphviz documentation for more information on available + attributes. https://graphviz.org/docs/nodes/ Examples -------- @@ -603,6 +617,16 @@ def model_to_graphviz( # creates the file `schools.pdf` model_to_graphviz(schools).render("schools") + + Display Free Random Variables and Observed Random Variables nodes with custom formatting. + + .. code-block:: python + + node_formatters = { + "Free Random Variable": lambda var: {"shape": "circle", "label": var.name}, + "Observed Random Variable": lambda var: {"shape": "square", "label": var.name}, + } + model_to_graphviz(schools, node_formatters=node_formatters) """ if "plain" not in formatting: raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.") From 052095b39c248d090a9d5e9d7112a0de57586ad0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 7 May 2024 23:01:30 +0200 Subject: [PATCH 5/8] add some tests --- tests/test_model_graph.py | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index 05b8b99e4..fd7f693c6 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -421,3 +421,47 @@ def test_model_graph_with_intermediate_named_variables(): b.name = "b" pm.Normal("c", b, 1) assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}} + + +@pytest.fixture +def simple_model() -> pm.Model: + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Normal("b", mu=a) + c = pm.Normal("c", mu=b) + + return model + + +def test_unknown_node_type(simple_model): + with pytest.raises(ValueError, match="Node formatters must be of type NodeType."): + model_to_graphviz(simple_model, node_formatters={"Unknown Node Type": "dummy"}) + + +def test_custom_node_formatting_networkx(simple_model): + node_formatters = { + "Free Random Variable": lambda var: { + "label": var.name, + }, + } + + G = model_to_networkx(simple_model, node_formatters=node_formatters) + assert G.__dict__["_node"] == { + "a": {"label": "a"}, + "b": {"label": "b"}, + "c": {"label": "c"}, + } + + +@pytest.mark.xfail(reason="Graphviz is not deterministic") +def test_custom_node_formatting_graphviz(simple_model): + node_formatters = { + "Free Random Variable": lambda var: { + "label": var.name, + }, + } + + G = model_to_graphviz(simple_model, node_formatters=node_formatters) + assert G.source == ( + "digraph {\n\ta [label=a]\n\tb [label=b]" "\n\tc [label=c]\n\ta -> b\n\tb -> c\n}\n" + ) From fbd904a6491329bb82219a875d96a3cea07e2408 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 8 May 2024 08:05:41 +0200 Subject: [PATCH 6/8] link to networkx add_node --- pymc/model_graph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index e7bf16f89..2910e49d4 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -491,6 +491,8 @@ def model_to_networkx( one of { "plain" } node_formatters : dict, optional A dictionary mapping node types to functions that return a dictionary of node attributes. + Check out the networkx documentation for more information + how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html Examples -------- From 403a81de4b58698b362c8f40dbe32233695cdb16 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 8 May 2024 08:34:42 +0200 Subject: [PATCH 7/8] extract from body --- tests/test_model_graph.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index fd7f693c6..b0ba1dba8 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -453,7 +453,6 @@ def test_custom_node_formatting_networkx(simple_model): } -@pytest.mark.xfail(reason="Graphviz is not deterministic") def test_custom_node_formatting_graphviz(simple_model): node_formatters = { "Free Random Variable": lambda var: { @@ -462,6 +461,14 @@ def test_custom_node_formatting_graphviz(simple_model): } G = model_to_graphviz(simple_model, node_formatters=node_formatters) - assert G.source == ( - "digraph {\n\ta [label=a]\n\tb [label=b]" "\n\tc [label=c]\n\ta -> b\n\tb -> c\n}\n" - ) + body = [item.strip() for item in G.body] + + items = [ + "a [label=a]", + "b [label=b]", + "c [label=c]", + "a -> b", + "b -> c", + ] + for item in items: + assert item in body From d1af9da44cf80bbcd957ad273e643a403b5aa3ad Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 8 May 2024 08:37:35 +0200 Subject: [PATCH 8/8] use set operation --- tests/test_model_graph.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py index b0ba1dba8..718f38e53 100644 --- a/tests/test_model_graph.py +++ b/tests/test_model_graph.py @@ -461,14 +461,15 @@ def test_custom_node_formatting_graphviz(simple_model): } G = model_to_graphviz(simple_model, node_formatters=node_formatters) - body = [item.strip() for item in G.body] - - items = [ - "a [label=a]", - "b [label=b]", - "c [label=c]", - "a -> b", - "b -> c", - ] - for item in items: - assert item in body + body = set(item.strip() for item in G.body) + + items = set( + [ + "a [label=a]", + "b [label=b]", + "c [label=c]", + "a -> b", + "b -> c", + ] + ) + assert body == items