Skip to content

Commit fb24fd3

Browse files
author
kueitang
committed
[Visualization Relay IR on terminal]
-Add a AST dump pass -It provides a snap shot to the relay IR graph
1 parent 26c2a9a commit fb24fd3

File tree

2 files changed

+529
-0
lines changed

2 files changed

+529
-0
lines changed

python/tvm/contrib/retv.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Relay Expression Terminal Visualization (RETV), visualizing Relay Expression on Terminal"""
18+
from tvm import relay
19+
from tvm import ir
20+
21+
22+
class Node:
23+
"""Base unit of a relay IR visualization node.
24+
25+
Parameters
26+
----------
27+
expr : expr
28+
Relay IR expression.
29+
30+
name : str
31+
The name of the relay IR node.
32+
33+
parent : Node
34+
The parent node of the relay IR node.
35+
36+
is_last : bool
37+
Whether the node is the last within same level nodes.
38+
"""
39+
40+
def __init__(self, expr, name, parent, is_last):
41+
self.expr = expr
42+
self.name = name
43+
self.parent = parent
44+
self.is_last = is_last
45+
self.children = []
46+
self.prefix = ""
47+
48+
49+
@ir.transform.module_pass(opt_level=1)
50+
class ASTVisualization:
51+
"""To visualize the relay IR graph on terminal."""
52+
53+
def __init__(self):
54+
self.output = []
55+
56+
def get_output(self):
57+
"""
58+
Returns
59+
-------
60+
output: str
61+
The graph.
62+
"""
63+
output = "== The AST view of the IRModule is ==\n"
64+
for subout in self.output[1:]:
65+
output += subout + "\n"
66+
output += self.output[0] + "\n" # "main" function
67+
return output
68+
69+
def transform_module(self, mod):
70+
"""A module pass"""
71+
72+
class ASTVisitor(relay.ExprVisitor):
73+
"""
74+
A visitor over Expr.
75+
76+
It traverses the AST recursively, and each node information into a sequence.
77+
"""
78+
79+
def __init__(self):
80+
super(ASTVisitor, self).__init__()
81+
self.sequence = []
82+
self.parent_stack = []
83+
self.last_stack = []
84+
self.current_subgraph = ""
85+
86+
def seen_node(self, new_node, expr):
87+
"""Record those seen expression"""
88+
self.sequence.append(new_node)
89+
self.parent_stack.append(new_node)
90+
for expr_child in self.memo_map[expr].children:
91+
new_node = Node(
92+
expr=expr_child,
93+
name=self.memo_map[expr_child].name,
94+
parent=self.parent_stack[-1],
95+
is_last=self.memo_map[expr_child].is_last,
96+
)
97+
self.seen_node(new_node, expr_child)
98+
self.parent_stack.pop()
99+
100+
def visit(self, expr):
101+
if expr in self.memo_map:
102+
new_node = Node(
103+
expr=expr,
104+
name=self.memo_map[expr].name,
105+
parent=self.parent_stack[-1],
106+
is_last=self.last_stack[-1],
107+
)
108+
self.seen_node(new_node, expr)
109+
else:
110+
super(ASTVisitor, self).visit(expr)
111+
112+
def visit_tuple(self, tup):
113+
node = Node(
114+
expr=tup,
115+
name="(tuple)",
116+
parent=self.parent_stack[-1],
117+
is_last=self.last_stack[-1],
118+
)
119+
self.sequence.append(node)
120+
node.parent.children.append(tup)
121+
self.parent_stack.append(node)
122+
for i, x in enumerate(tup.fields):
123+
if i == len(tup.fields) - 1:
124+
self.last_stack.append(True)
125+
else:
126+
self.last_stack.append(False)
127+
self.visit(x)
128+
self.last_stack.pop()
129+
self.parent_stack.pop()
130+
return node
131+
132+
def visit_var(self, var):
133+
node = Node(
134+
expr=var,
135+
name=var.name_hint,
136+
parent=self.parent_stack[-1],
137+
is_last=self.last_stack[-1],
138+
)
139+
self.sequence.append(node)
140+
node.parent.children.append(var)
141+
return node
142+
143+
def visit_function(self, fn):
144+
if len(self.sequence) == 0: # entry function call
145+
layer_name = "@" + self.current_subgraph + "(" + str(f.params) + ")"
146+
self.parent_stack = [None]
147+
self.last_stack = [True]
148+
else:
149+
layer_name = "Function_" + str(fn.__hash__()) + "(" + str(fn.params) + ")"
150+
151+
node = Node(
152+
expr=fn,
153+
name=layer_name,
154+
parent=self.parent_stack[-1],
155+
is_last=self.last_stack[-1],
156+
)
157+
self.sequence.append(node)
158+
if node.parent is not None:
159+
node.parent.children.append(fn)
160+
161+
is_last = True
162+
self.last_stack.append(is_last)
163+
self.parent_stack.append(node)
164+
self.visit(fn.body)
165+
self.parent_stack.pop()
166+
self.last_stack.pop()
167+
return node
168+
169+
def visit_call(self, call):
170+
layer_name = "(call)"
171+
node = Node(
172+
expr=call,
173+
name=layer_name,
174+
parent=self.parent_stack[-1],
175+
is_last=self.last_stack[-1],
176+
)
177+
self.sequence.append(node)
178+
node.parent.children.append(call)
179+
self.parent_stack.append(node)
180+
self.last_stack.append(len(call.args) == 0)
181+
self.visit(call.op)
182+
self.last_stack.pop()
183+
184+
for i, arg in enumerate(call.args):
185+
is_last = i == len(call.args) - 1
186+
self.last_stack.append(is_last)
187+
self.visit(arg)
188+
self.last_stack.pop()
189+
self.parent_stack.pop()
190+
return node
191+
192+
def visit_constant(self, const):
193+
node = Node(
194+
expr=const,
195+
name=const,
196+
parent=self.parent_stack[-1],
197+
is_last=self.last_stack[-1],
198+
)
199+
self.sequence.append(node)
200+
node.parent.children.append(const)
201+
return node
202+
203+
def visit_if(self, i):
204+
layer_name = "if(cond, true, false)"
205+
node = Node(
206+
expr=i,
207+
name=layer_name,
208+
parent=self.parent_stack[-1],
209+
is_last=self.last_stack[-1],
210+
)
211+
node.parent.children.append(node)
212+
self.sequence.append(node)
213+
self.parent_stack.append(node)
214+
self.last_stack.append(False)
215+
self.visit(i.cond)
216+
self.last_stack[-1] = False
217+
self.visit(i.true_branch)
218+
self.last_stack[-1] = True
219+
self.visit(i.false_branch)
220+
self.last_stack.pop()
221+
self.parent_stack.pop()
222+
return node
223+
224+
def visit_let(self, let):
225+
layer_name = "let(var, val, body)"
226+
node = Node(
227+
expr=let,
228+
name=layer_name,
229+
parent=self.parent_stack[-1],
230+
is_last=self.last_stack[-1],
231+
)
232+
self.sequence.append(node)
233+
node.parent.children.append(let)
234+
self.parent_stack.append(node)
235+
self.last_stack.append(False)
236+
self.visit(let.var)
237+
self.last_stack[-1] = False
238+
self.visit(let.value)
239+
self.last_stack[-1] = True
240+
self.visit(let.body)
241+
self.last_stack.pop()
242+
self.parent_stack.pop()
243+
return node
244+
245+
def visit_global_var(self, gv):
246+
layer_name = "@" + str(gv.name_hint)
247+
node = Node(
248+
expr=gv,
249+
name=layer_name,
250+
parent=self.parent_stack[-1],
251+
is_last=self.last_stack[-1],
252+
)
253+
self.sequence.append(node)
254+
node.parent.children.append(gv)
255+
return node
256+
257+
def visit_op(self, op):
258+
node = Node(
259+
expr=op,
260+
name=str(op.name),
261+
parent=self.parent_stack[-1],
262+
is_last=self.last_stack[-1],
263+
)
264+
self.sequence.append(node)
265+
node.parent.children.append(op)
266+
return node
267+
268+
def prettyprint(self):
269+
"""Prettyprint the result"""
270+
271+
if len(self.sequence) <= 1:
272+
raise RuntimeError("It is an empty IRmodule")
273+
res = ""
274+
res += self.sequence[0].name + "\n"
275+
for node in self.sequence[1:]:
276+
if node.parent is None:
277+
part_a = ""
278+
part_b = ""
279+
else:
280+
part_a = node.parent.prefix[:-3]
281+
part_b = " " * 3 if node.parent.is_last else "| "
282+
part_c = "`--" if node.is_last else "|--"
283+
if isinstance(node.expr, Tuple):
284+
name = ""
285+
for child in node.children:
286+
name += str(self.memo_map[child].name) + ", "
287+
name = "(" + name[:-2] + ")"
288+
node.name = name
289+
node.prefix = part_a + part_b + part_c
290+
res += node.prefix + str(node.name) + "\n"
291+
return res
292+
293+
printer = ASTVisitor()
294+
printer.current_subgraph = "main"
295+
printer.visit(mod["main"])
296+
self.output.append(printer.prettyprint())
297+
for subgraph in mod.get_global_vars():
298+
name = subgraph.name_hint
299+
if name != "main":
300+
printer.sequence = []
301+
printer.parent_stack = []
302+
printer.last_stack = []
303+
printer.current_subgraph = name
304+
printer.visit(mod[name])
305+
self.output.append(printer.prettyprint())
306+
return mod

0 commit comments

Comments
 (0)