Skip to content

Commit af7361b

Browse files
kueitanggracetang
authored andcommitted
[Visualization Relay IR on terminal]
-Add a AST dump pass -It provides a snap shot to the relay IR graph
1 parent 26c2a9a commit af7361b

File tree

2 files changed

+530
-0
lines changed

2 files changed

+530
-0
lines changed

python/tvm/contrib/retv.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Relay Expression Terminal Visualization (RETV), visualizing Relay Expression on Terminal"""
18+
from tvm import relay
19+
from tvm import ir
20+
from tvm.relay import Tuple
21+
22+
23+
class Node:
24+
"""Base unit of a relay IR visualization node.
25+
26+
Parameters
27+
----------
28+
expr : expr
29+
Relay IR expression.
30+
31+
name : str
32+
The name of the relay IR node.
33+
34+
parent : Node
35+
The parent node of the relay IR node.
36+
37+
is_last : bool
38+
Whether the node is the last within same level nodes.
39+
"""
40+
41+
def __init__(self, expr, name, parent, is_last):
42+
self.expr = expr
43+
self.name = name
44+
self.parent = parent
45+
self.is_last = is_last
46+
self.children = []
47+
self.prefix = ""
48+
49+
50+
@ir.transform.module_pass(opt_level=1)
51+
class ASTVisualization:
52+
"""To visualize the relay IR graph on terminal."""
53+
54+
def __init__(self):
55+
self.output = []
56+
57+
def get_output(self):
58+
"""
59+
Returns
60+
-------
61+
output: str
62+
The graph.
63+
"""
64+
output = "== The AST view of the IRModule is ==\n"
65+
for subout in self.output[1:]:
66+
output += subout + "\n"
67+
output += self.output[0] + "\n" # "main" function
68+
return output
69+
70+
def transform_module(self, mod, ctx):
71+
"""A module pass"""
72+
73+
class ASTVisitor(relay.ExprVisitor):
74+
"""
75+
A visitor over Expr.
76+
77+
It traverses the AST recursively, and each node information into a sequence.
78+
"""
79+
80+
def __init__(self):
81+
super(ASTVisitor, self).__init__()
82+
self.sequence = []
83+
self.parent_stack = []
84+
self.last_stack = []
85+
self.current_subgraph = ""
86+
87+
def seen_node(self, new_node, expr):
88+
"""Record those seen expression"""
89+
self.sequence.append(new_node)
90+
self.parent_stack.append(new_node)
91+
for expr_child in self.memo_map[expr].children:
92+
new_node = Node(
93+
expr=expr_child,
94+
name=self.memo_map[expr_child].name,
95+
parent=self.parent_stack[-1],
96+
is_last=self.memo_map[expr_child].is_last,
97+
)
98+
self.seen_node(new_node, expr_child)
99+
self.parent_stack.pop()
100+
101+
def visit(self, expr):
102+
if expr in self.memo_map:
103+
new_node = Node(
104+
expr=expr,
105+
name=self.memo_map[expr].name,
106+
parent=self.parent_stack[-1],
107+
is_last=self.last_stack[-1],
108+
)
109+
self.seen_node(new_node, expr)
110+
else:
111+
super(ASTVisitor, self).visit(expr)
112+
113+
def visit_tuple(self, tup):
114+
node = Node(
115+
expr=tup,
116+
name="(tuple)",
117+
parent=self.parent_stack[-1],
118+
is_last=self.last_stack[-1],
119+
)
120+
self.sequence.append(node)
121+
node.parent.children.append(tup)
122+
self.parent_stack.append(node)
123+
for i, x in enumerate(tup.fields):
124+
if i == len(tup.fields) - 1:
125+
self.last_stack.append(True)
126+
else:
127+
self.last_stack.append(False)
128+
self.visit(x)
129+
self.last_stack.pop()
130+
self.parent_stack.pop()
131+
return node
132+
133+
def visit_var(self, var):
134+
node = Node(
135+
expr=var,
136+
name=var.name_hint,
137+
parent=self.parent_stack[-1],
138+
is_last=self.last_stack[-1],
139+
)
140+
self.sequence.append(node)
141+
node.parent.children.append(var)
142+
return node
143+
144+
def visit_function(self, fn):
145+
if len(self.sequence) == 0: # entry function call
146+
layer_name = "@" + self.current_subgraph + "(" + str(fn.params) + ")"
147+
self.parent_stack = [None]
148+
self.last_stack = [True]
149+
else:
150+
layer_name = "Function_" + str(fn.__hash__()) + "(" + str(fn.params) + ")"
151+
152+
node = Node(
153+
expr=fn,
154+
name=layer_name,
155+
parent=self.parent_stack[-1],
156+
is_last=self.last_stack[-1],
157+
)
158+
self.sequence.append(node)
159+
if node.parent is not None:
160+
node.parent.children.append(fn)
161+
162+
is_last = True
163+
self.last_stack.append(is_last)
164+
self.parent_stack.append(node)
165+
self.visit(fn.body)
166+
self.parent_stack.pop()
167+
self.last_stack.pop()
168+
return node
169+
170+
def visit_call(self, call):
171+
layer_name = "(call)"
172+
node = Node(
173+
expr=call,
174+
name=layer_name,
175+
parent=self.parent_stack[-1],
176+
is_last=self.last_stack[-1],
177+
)
178+
self.sequence.append(node)
179+
node.parent.children.append(call)
180+
self.parent_stack.append(node)
181+
self.last_stack.append(len(call.args) == 0)
182+
self.visit(call.op)
183+
self.last_stack.pop()
184+
185+
for i, arg in enumerate(call.args):
186+
is_last = i == len(call.args) - 1
187+
self.last_stack.append(is_last)
188+
self.visit(arg)
189+
self.last_stack.pop()
190+
self.parent_stack.pop()
191+
return node
192+
193+
def visit_constant(self, const):
194+
node = Node(
195+
expr=const,
196+
name=const,
197+
parent=self.parent_stack[-1],
198+
is_last=self.last_stack[-1],
199+
)
200+
self.sequence.append(node)
201+
node.parent.children.append(const)
202+
return node
203+
204+
def visit_if(self, i):
205+
layer_name = "if(cond, true, false)"
206+
node = Node(
207+
expr=i,
208+
name=layer_name,
209+
parent=self.parent_stack[-1],
210+
is_last=self.last_stack[-1],
211+
)
212+
node.parent.children.append(node)
213+
self.sequence.append(node)
214+
self.parent_stack.append(node)
215+
self.last_stack.append(False)
216+
self.visit(i.cond)
217+
self.last_stack[-1] = False
218+
self.visit(i.true_branch)
219+
self.last_stack[-1] = True
220+
self.visit(i.false_branch)
221+
self.last_stack.pop()
222+
self.parent_stack.pop()
223+
return node
224+
225+
def visit_let(self, let):
226+
layer_name = "let(var, val, body)"
227+
node = Node(
228+
expr=let,
229+
name=layer_name,
230+
parent=self.parent_stack[-1],
231+
is_last=self.last_stack[-1],
232+
)
233+
self.sequence.append(node)
234+
node.parent.children.append(let)
235+
self.parent_stack.append(node)
236+
self.last_stack.append(False)
237+
self.visit(let.var)
238+
self.last_stack[-1] = False
239+
self.visit(let.value)
240+
self.last_stack[-1] = True
241+
self.visit(let.body)
242+
self.last_stack.pop()
243+
self.parent_stack.pop()
244+
return node
245+
246+
def visit_global_var(self, gv):
247+
layer_name = "@" + str(gv.name_hint)
248+
node = Node(
249+
expr=gv,
250+
name=layer_name,
251+
parent=self.parent_stack[-1],
252+
is_last=self.last_stack[-1],
253+
)
254+
self.sequence.append(node)
255+
node.parent.children.append(gv)
256+
return node
257+
258+
def visit_op(self, op):
259+
node = Node(
260+
expr=op,
261+
name=str(op.name),
262+
parent=self.parent_stack[-1],
263+
is_last=self.last_stack[-1],
264+
)
265+
self.sequence.append(node)
266+
node.parent.children.append(op)
267+
return node
268+
269+
def prettyprint(self):
270+
"""Prettyprint the result"""
271+
272+
if len(self.sequence) <= 1:
273+
raise RuntimeError("It is an empty IRmodule")
274+
res = ""
275+
res += self.sequence[0].name + "\n"
276+
for node in self.sequence[1:]:
277+
if node.parent is None:
278+
part_a = ""
279+
part_b = ""
280+
else:
281+
part_a = node.parent.prefix[:-3]
282+
part_b = " " * 3 if node.parent.is_last else "| "
283+
part_c = "`--" if node.is_last else "|--"
284+
if isinstance(node.expr, Tuple):
285+
name = ""
286+
for child in node.children:
287+
name += str(self.memo_map[child].name) + ", "
288+
name = "(" + name[:-2] + ")"
289+
node.name = name
290+
node.prefix = part_a + part_b + part_c
291+
res += node.prefix + str(node.name) + "\n"
292+
return res
293+
294+
printer = ASTVisitor()
295+
printer.current_subgraph = "main"
296+
printer.visit(mod["main"])
297+
self.output.append(printer.prettyprint())
298+
for subgraph in mod.get_global_vars():
299+
name = subgraph.name_hint
300+
if name != "main":
301+
printer.sequence = []
302+
printer.parent_stack = []
303+
printer.last_stack = []
304+
printer.current_subgraph = name
305+
printer.visit(mod[name])
306+
self.output.append(printer.prettyprint())
307+
return mod

0 commit comments

Comments
 (0)