|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=line-too-long |
| 18 | +""" |
| 19 | +Use Relay Visualizer to Visualize Relay |
| 20 | +============================================================ |
| 21 | +**Author**: `Chi-Wei Wang <https://github.com/chiwwang>`_ |
| 22 | +
|
| 23 | +This is an introduction about using Relay Visualizer to visualize a Relay IR module. |
| 24 | +
|
| 25 | +Relay IR module can contain lots of operations. Although individual |
| 26 | +operations are usually easy to understand, they become complicated quickly |
| 27 | +when you put them together. It could get even worse while optimiztion passes |
| 28 | +come into play. |
| 29 | +
|
| 30 | +This utility abstracts an IR module as graphs containing nodes and edges. |
| 31 | +It provides a default parser to interpret an IR modules with nodes and edges. |
| 32 | +Two renderer backends are also implemented to visualize them. |
| 33 | +
|
| 34 | +Here we use a backend showing Relay IR module in the terminal for illustation. |
| 35 | +It is a much more lightweight compared to another backend using `Bokeh <https://docs.bokeh.org/en/latest/>`_. |
| 36 | +See ``<TVM_HOME>/python/tvm/contrib/relay_viz/README.md``. |
| 37 | +Also we will introduce how to implement customized parsers and renderers through |
| 38 | +some interfaces classes. |
| 39 | +""" |
| 40 | +from typing import ( |
| 41 | + Dict, |
| 42 | + Union, |
| 43 | + Tuple, |
| 44 | + List, |
| 45 | +) |
| 46 | +import tvm |
| 47 | +from tvm import relay |
| 48 | +from tvm.contrib import relay_viz |
| 49 | +from tvm.contrib.relay_viz.node_edge_gen import ( |
| 50 | + VizNode, |
| 51 | + VizEdge, |
| 52 | + NodeEdgeGenerator, |
| 53 | +) |
| 54 | +from tvm.contrib.relay_viz.terminal import ( |
| 55 | + TermNodeEdgeGenerator, |
| 56 | + TermGraph, |
| 57 | + TermPlotter, |
| 58 | +) |
| 59 | + |
| 60 | +###################################################################### |
| 61 | +# Define a Relay IR Module with multiple GlobalVar |
| 62 | +# ------------------------------------------------ |
| 63 | +# Let's build an example Relay IR Module containing multiple ``GlobalVar``. |
| 64 | +# We define an ``add`` function and call it in the main function. |
| 65 | +data = relay.var("data") |
| 66 | +bias = relay.var("bias") |
| 67 | +add_op = relay.add(data, bias) |
| 68 | +add_func = relay.Function([data, bias], add_op) |
| 69 | +add_gvar = relay.GlobalVar("AddFunc") |
| 70 | + |
| 71 | +input0 = relay.var("input0") |
| 72 | +input1 = relay.var("input1") |
| 73 | +input2 = relay.var("input2") |
| 74 | +add_01 = relay.Call(add_gvar, [input0, input1]) |
| 75 | +add_012 = relay.Call(add_gvar, [input2, add_01]) |
| 76 | +main_func = relay.Function([input0, input1, input2], add_012) |
| 77 | +main_gvar = relay.GlobalVar("main") |
| 78 | + |
| 79 | +mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) |
| 80 | + |
| 81 | +###################################################################### |
| 82 | +# Render the graph with Relay Visualizer on the terminal |
| 83 | +# ------------------------------------------------------ |
| 84 | +# The terminal backend can show a Relay IR module as in a text-form |
| 85 | +# similar to `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html#examining-the-ast>`_. |
| 86 | +# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function. |
| 87 | +viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL) |
| 88 | +viz.render() |
| 89 | + |
| 90 | +###################################################################### |
| 91 | +# Customize Parser for Interested Relay Types |
| 92 | +# ------------------------------------------- |
| 93 | +# Sometimes the information shown by the default implementation is not suitable |
| 94 | +# for a specific usage. It is possible to provide your own parser and renderer. |
| 95 | +# Here demostrate how to customize parsers for ``relay.var``. |
| 96 | +# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface. |
| 97 | +class YourAwesomeParser(NodeEdgeGenerator): |
| 98 | + def __init__(self): |
| 99 | + self._org_parser = TermNodeEdgeGenerator() |
| 100 | + |
| 101 | + def get_node_edges( |
| 102 | + self, |
| 103 | + node: relay.Expr, |
| 104 | + relay_param: Dict[str, tvm.runtime.NDArray], |
| 105 | + node_to_id: Dict[relay.Expr, Union[int, str]], |
| 106 | + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: |
| 107 | + |
| 108 | + if isinstance(node, relay.Var): |
| 109 | + node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}") |
| 110 | + # no edge is introduced. So return an empty list. |
| 111 | + ret = (node, []) |
| 112 | + return ret |
| 113 | + |
| 114 | + # delegate other types to the original parser. |
| 115 | + return self._org_parser.get_node_edges(node, relay_param, node_to_id) |
| 116 | + |
| 117 | + |
| 118 | +###################################################################### |
| 119 | +# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and |
| 120 | +# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances |
| 121 | +# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside |
| 122 | +# ``relay_viz.terminal`` module. |
| 123 | +viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser())) |
| 124 | +viz.render() |
| 125 | + |
| 126 | +###################################################################### |
| 127 | +# More Customization around Graph and Plotter |
| 128 | +# ------------------------------------------- |
| 129 | +# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and |
| 130 | +# ``node_edge_generator.py``. We can override them to introduce custimized logics. |
| 131 | +# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added, |
| 132 | +# we can override ``relay_viz.terminal.TermGraph.node``. |
| 133 | +class AwesomeGraph(TermGraph): |
| 134 | + def node(self, node_id, node_type, node_detail): |
| 135 | + # add original node first |
| 136 | + super().node(node_id, node_type, node_detail) |
| 137 | + if node_type == "AwesomeVar": |
| 138 | + duplicated_id = f"duplciated_{node_id}" |
| 139 | + duplicated_type = "double AwesomeVar" |
| 140 | + super().node(duplicated_id, duplicated_type, "") |
| 141 | + # connect the duplicated var to the original one |
| 142 | + super().edge(duplicated_id, node_id) |
| 143 | + |
| 144 | + |
| 145 | +# override TermPlotter to return `AwesomeGraph` instead |
| 146 | +class AwesomePlotter(TermPlotter): |
| 147 | + def create_graph(self, name): |
| 148 | + self._name_to_graph[name] = AwesomeGraph(name) |
| 149 | + return self._name_to_graph[name] |
| 150 | + |
| 151 | + |
| 152 | +viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser())) |
| 153 | +viz.render() |
| 154 | + |
| 155 | +###################################################################### |
| 156 | +# Summary |
| 157 | +# ------- |
| 158 | +# This tutorial demonstrates the usage of Relay Visualizer. |
| 159 | +# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces |
| 160 | +# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point |
| 161 | +# while keeping the possibility of implementing customized visualizer in various cases. |
| 162 | +# |
0 commit comments