Skip to content

Commit 9955d7c

Browse files
author
chiwwang
committed
tutorial and doc
1 parent fb2cf61 commit 9955d7c

File tree

8 files changed

+326
-128
lines changed

8 files changed

+326
-128
lines changed

docs/reference/api/python/contrib.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ tvm.contrib.random
9292
.. automodule:: tvm.contrib.random
9393
:members:
9494

95+
tvm.contrib.relay_viz
96+
~~~~~~~~~~~~~~~~~~~~~
97+
.. automodule:: tvm.contrib.relay_viz
98+
:members: RelayVisualizer
99+
.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.BOKEH
100+
.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.TERMINAL
101+
.. automodule:: tvm.contrib.relay_viz.plotter
102+
:members:
103+
.. automodule:: tvm.contrib.relay_viz.node_edge_gen
104+
:members:
95105

96106
tvm.contrib.rocblas
97107
~~~~~~~~~~~~~~~~~~~
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=line-too-long
18+
"""
19+
Use Relay Visualizer to Visualize Relay
20+
============================================================
21+
**Author**: `Chi-Wei Wang <https://github.com/chiwwang>`_
22+
23+
This is an introduction about using Relay Visualizer to visualize a Relay IR module.
24+
25+
Relay IR module can contain lots of operations. Although individual
26+
operations are usually easy to understand, they become complicated quickly
27+
when you put them together. It could get even worse while optimiztion passes
28+
come into play.
29+
30+
This utility abstracts an IR module as graphs containing nodes and edges.
31+
It provides a default parser to interpret an IR modules with nodes and edges.
32+
Two renderer backends are also implemented to visualize them.
33+
34+
Here we use a backend showing Relay IR module in the terminal for illustation.
35+
It is a much more lightweight compared to another backend using `Bokeh <https://docs.bokeh.org/en/latest/>`_.
36+
See ``<TVM_HOME>/python/tvm/contrib/relay_viz/README.md``.
37+
Also we will introduce how to implement customized parsers and renderers through
38+
some interfaces classes.
39+
"""
40+
from typing import (
41+
Dict,
42+
Union,
43+
Tuple,
44+
List,
45+
)
46+
import tvm
47+
from tvm import relay
48+
from tvm.contrib import relay_viz
49+
from tvm.contrib.relay_viz.node_edge_gen import (
50+
VizNode,
51+
VizEdge,
52+
NodeEdgeGenerator,
53+
)
54+
from tvm.contrib.relay_viz.terminal import (
55+
TermNodeEdgeGenerator,
56+
TermGraph,
57+
TermPlotter,
58+
)
59+
60+
######################################################################
61+
# Define a Relay IR Module with multiple GlobalVar
62+
# ------------------------------------------------
63+
# Let's build an example Relay IR Module containing multiple ``GlobalVar``.
64+
# We define an ``add`` function and call it in the main function.
65+
data = relay.var("data")
66+
bias = relay.var("bias")
67+
add_op = relay.add(data, bias)
68+
add_func = relay.Function([data, bias], add_op)
69+
add_gvar = relay.GlobalVar("AddFunc")
70+
71+
input0 = relay.var("input0")
72+
input1 = relay.var("input1")
73+
input2 = relay.var("input2")
74+
add_01 = relay.Call(add_gvar, [input0, input1])
75+
add_012 = relay.Call(add_gvar, [input2, add_01])
76+
main_func = relay.Function([input0, input1, input2], add_012)
77+
main_gvar = relay.GlobalVar("main")
78+
79+
mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})
80+
81+
######################################################################
82+
# Render the graph with Relay Visualizer on the terminal
83+
# ------------------------------------------------------
84+
# The terminal backend can show a Relay IR module as in a text-form
85+
# similar to `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html#examining-the-ast>`_.
86+
# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function.
87+
viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL)
88+
viz.render()
89+
90+
######################################################################
91+
# Customize Parser for Interested Relay Types
92+
# -------------------------------------------
93+
# Sometimes the information shown by the default implementation is not suitable
94+
# for a specific usage. It is possible to provide your own parser and renderer.
95+
# Here demostrate how to customize parsers for ``relay.var``.
96+
# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface.
97+
class YourAwesomeParser(NodeEdgeGenerator):
98+
def __init__(self):
99+
self._org_parser = TermNodeEdgeGenerator()
100+
101+
def get_node_edges(
102+
self,
103+
node: relay.Expr,
104+
relay_param: Dict[str, tvm.runtime.NDArray],
105+
node_to_id: Dict[relay.Expr, Union[int, str]],
106+
) -> Tuple[Union[VizNode, None], List[VizEdge]]:
107+
108+
if isinstance(node, relay.Var):
109+
node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
110+
# no edge is introduced. So return an empty list.
111+
ret = (node, [])
112+
return ret
113+
114+
# delegate other types to the original parser.
115+
return self._org_parser.get_node_edges(node, relay_param, node_to_id)
116+
117+
118+
######################################################################
119+
# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and
120+
# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances
121+
# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside
122+
# ``relay_viz.terminal`` module.
123+
viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser()))
124+
viz.render()
125+
126+
######################################################################
127+
# More Customization around Graph and Plotter
128+
# -------------------------------------------
129+
# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and
130+
# ``node_edge_generator.py``. We can override them to introduce custimized logics.
131+
# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added,
132+
# we can override ``relay_viz.terminal.TermGraph.node``.
133+
class AwesomeGraph(TermGraph):
134+
def node(self, node_id, node_type, node_detail):
135+
# add original node first
136+
super().node(node_id, node_type, node_detail)
137+
if node_type == "AwesomeVar":
138+
duplicated_id = f"duplciated_{node_id}"
139+
duplicated_type = "double AwesomeVar"
140+
super().node(duplicated_id, duplicated_type, "")
141+
# connect the duplicated var to the original one
142+
super().edge(duplicated_id, node_id)
143+
144+
145+
# override TermPlotter to return `AwesomeGraph` instead
146+
class AwesomePlotter(TermPlotter):
147+
def create_graph(self, name):
148+
self._name_to_graph[name] = AwesomeGraph(name)
149+
return self._name_to_graph[name]
150+
151+
152+
viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser()))
153+
viz.render()
154+
155+
######################################################################
156+
# Summary
157+
# -------
158+
# This tutorial demonstrates the usage of Relay Visualizer.
159+
# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces
160+
# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point
161+
# while keeping the possibility of implementing customized visualizer in various cases.
162+
#

python/tvm/contrib/relay_viz/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ This tool target to visualize Relay IR.
2828

2929
## Requirement
3030

31+
### Terminal Backend
32+
1. TVM
33+
34+
### Bokeh Backend
3135
1. TVM
3236
2. graphviz
3337
2. pydot
@@ -66,9 +70,9 @@ This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`.
6670

6771
`plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`.
6872

69-
`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes/edges consumed by `Graph`. Further, this python module also provide a default implementation for common relay types.
73+
`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes and edges. Further, this python module provide a default implementation for common relay types.
7074

7175
If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage.
7276

73-
These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes/edges to `Graph`.
74-
Then, it render the plot by `Plotter.render()`.
77+
These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes and edges to `Graph`.
78+
Then, it render the plot by calling `Plotter.render()`.

python/tvm/contrib/relay_viz/__init__.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,33 @@
2828

2929

3030
class PlotterBackend(Enum):
31-
"""Enumeration for available plotters."""
31+
"""Enumeration for available plotter backends."""
3232

3333
BOKEH = "bokeh"
3434
TERMINAL = "terminal"
3535

3636

3737
class RelayVisualizer:
38-
"""Relay IR Visualizer"""
38+
"""Relay IR Visualizer
39+
40+
Parameters
41+
----------
42+
relay_mod : tvm.IRModule
43+
Relay IR module.
44+
relay_param: None | Dict[str, tvm.runtime.NDArray]
45+
Relay parameter dictionary. Default `None`.
46+
backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator]
47+
The backend used to render graphs. It can be a tuple of an implemented Plotter instance and
48+
NodeEdgeGenerator instance to introduce customized parsing and visualization logics.
49+
Default ``PlotterBackend.TERMINAL``.
50+
"""
3951

4052
def __init__(
4153
self,
4254
relay_mod: tvm.IRModule,
4355
relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None,
4456
backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL,
4557
):
46-
"""Visualize Relay IR.
47-
48-
Parameters
49-
----------
50-
relay_mod : tvm.IRModule, Relay IR module
51-
relay_param: None | Dict[str, tvm.runtime.NDArray], Relay parameter dictionary. Default `None`.
52-
backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator], Default `PlotterBackend.TERMINAL`.
53-
"""
5458

5559
self._plotter, self._ne_generator = get_plotter_and_generator(backend)
5660
self._relay_param = relay_param if relay_param is not None else {}
@@ -83,8 +87,10 @@ def _add_nodes(self, graph, node_to_id, relay_param):
8387
8488
Parameters
8589
----------
86-
graph : `plotter.Graph`
90+
graph : plotter.Graph
91+
8792
node_to_id : Dict[relay.expr, str | int]
93+
8894
relay_param : Dict[str, tvm.runtime.NDarray]
8995
"""
9096
for node in node_to_id:
@@ -102,11 +108,11 @@ def get_plotter_and_generator(backend):
102108
"""Specify the Plottor and its NodeEdgeGenerator"""
103109
if isinstance(backend, (tuple, list)) and len(backend) == 2:
104110
if not isinstance(backend[0], Plotter):
105-
raise ValueError(f"First element of backend should be derived from {type(Plotter)}")
111+
raise ValueError(f"First element should be an instance derived from {type(Plotter)}")
106112

107113
if not isinstance(backend[1], NodeEdgeGenerator):
108114
raise ValueError(
109-
f"Second element of backend should be derived from {type(NodeEdgeGenerator)}"
115+
f"Second element should be an instance derived from {type(NodeEdgeGenerator)}"
110116
)
111117

112118
return backend
@@ -118,7 +124,7 @@ def get_plotter_and_generator(backend):
118124
# Basically we want to keep them optional. Users can choose plotters they want to install.
119125
if backend == PlotterBackend.BOKEH:
120126
# pylint: disable=import-outside-toplevel
121-
from ._bokeh import (
127+
from .bokeh import (
122128
BokehPlotter,
123129
BokehNodeEdgeGenerator,
124130
)
@@ -127,7 +133,7 @@ def get_plotter_and_generator(backend):
127133
ne_generator = BokehNodeEdgeGenerator()
128134
elif backend == PlotterBackend.TERMINAL:
129135
# pylint: disable=import-outside-toplevel
130-
from ._terminal import (
136+
from .terminal import (
131137
TermPlotter,
132138
TermNodeEdgeGenerator,
133139
)

python/tvm/contrib/relay_viz/_bokeh.py renamed to python/tvm/contrib/relay_viz/bokeh.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,9 @@
1919
import functools
2020
import logging
2121

22-
_LOGGER = logging.getLogger(__name__)
23-
2422
import numpy as np
25-
26-
try:
27-
import pydot
28-
except ImportError:
29-
_LOGGER.critical("pydot library is required. You might want to run pip install pydot.")
30-
raise
31-
32-
try:
33-
from bokeh.io import output_file, save
34-
except ImportError:
35-
_LOGGER.critical("bokeh library is required. You might want to run pip install bokeh.")
36-
raise
37-
23+
import pydot
24+
from bokeh.io import output_file, save
3825
from bokeh.models import (
3926
ColumnDataSource,
4027
CustomJS,
@@ -63,6 +50,8 @@
6350

6451
from .node_edge_gen import DefaultNodeEdgeGenerator
6552

53+
_LOGGER = logging.getLogger(__name__)
54+
6655
# Use default node/edge generator
6756
BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator
6857

0 commit comments

Comments
 (0)