
Commit 594bc0f

Relay transform for rolling a known pattern into batch_matmul (#14210)
Relay transform for rolling strided_slice, dense and other ops into batch_matmul
1 parent f3b64b7 commit 594bc0f
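
In rough terms (an illustrative sketch, not text from the commit; variable names, shapes, and quantization constants mirror the pattern documented in the qdistilbert_rewrite docstring below), the twelve per-slice strided_slice / reshape / qnn.dense / qnn.requantize / expand_dims chains plus the final concatenate are collapsed into a single batched expression:

# Sketch of the batched expression the rewrite emits in place of the matched pattern.
from tvm import relay

A = relay.var("A", shape=(12, 128, 64), dtype="int8")  # Tensor A
B = relay.var("B", shape=(12, 64, 128), dtype="int8")  # Tensor B

# dense quantization params (input zero point, kernel zero point, input scale, kernel scale)
d = [relay.const(13, "int32"), relay.const(1, "int32"),
     relay.const(0.0541715, "float32"), relay.const(0.0489368, "float32")]
# requantize params (input scale, input zero point, output scale, output zero point)
q = [relay.const(0.00265098, "float32"), relay.const(0, "int32"),
     relay.const(0.728874, "float32"), relay.const(-14, "int32")]

out = relay.op.transpose(B, axes=[0, 2, 1])
out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], out_dtype="int32")
out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], out_dtype="int8")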

File tree

2 files changed: +283 −0 lines changed

python/tvm/contrib/hexagon/transform.py

Lines changed: 163 additions & 0 deletions
@@ -20,6 +20,9 @@
 import functools as ft

 import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
+from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
 from ..._ffi.registry import register_func

 ### VTCM
@@ -148,3 +151,163 @@ def transform(func, mod, ctx):

 def ir_lower_vtcm_pass():
     return [(3, ir_lower_vtcm())]
+
+
+class qdistilbert_rewrite(DFPatternCallback):
+    """
+    A callback to replace the below pattern:
+    Pattern:
+    %35 = strided_slice(%34, begin=[0, 0, 0], end=[1, 128, 64], strides=[1, 1, 1], axes=None);
+    %44 = reshape(%35, newshape=[-1, 64]);
+    <snip>
+    %42 = strided_slice(%41, begin=[0, 0, 0], end=[1, 64, 128], strides=[1, 1, 1], axes=None);
+    %43 = reshape(%42, newshape=[64, 128]);
+    %45 = transpose(%43, axes=[1, 0]);
+    <snip>
+    %46 = qnn.dense(%44, %45, 13, 1, 0.0541715f, 0.0489368f, units=None, out_dtype="int32");
+    %47 = qnn.requantize(%46, 0.00265098f, 0, 0.728874f, -14, axis=1, out_dtype="int8");
+    <snip>
+    %125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;
+    < The above pattern repeats 12 times, which is the batch size >
+
+    %137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136);
+    %138 = concatenate(%137);
+
+    """
+
+    def __init__(self):
+        super(qdistilbert_rewrite, self).__init__()
+        self.A = wildcard()  # Tensor A
+        self.B = wildcard()  # Tensor B
+        self.batch = 12  # Number of times the pattern repeats, i.e. the batch size
+
+        self.d = []  # List of dense quantization parameters
+        self.q = []  # List of requantize parameters
+        L = []  # List of patterns
+
+        z = tvm.tir.IntImm("int64", 0)
+        s1 = tvm.tir.IntImm("int64", 1)
+
+        for i in range(self.batch):
+            x = tvm.tir.IntImm("int64", i)
+
+            self.d.append([is_constant(), is_constant(), is_constant(), is_constant()])
+            self.q.append([is_constant(), is_constant(), is_constant(), is_constant()])
+
+            pat_a = is_op("strided_slice")(self.A).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_a = is_op("reshape")(pat_a)
+
+            pat_b = is_op("strided_slice")(self.B).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_b = is_op("reshape")(pat_b)
+            pat_b = is_op("transpose")(pat_b)
+
+            pat = is_op("qnn.dense")(
+                pat_a, pat_b, self.d[i][0], self.d[i][1], self.d[i][2], self.d[i][3]
+            )
+            pat = is_op("qnn.requantize")(
+                pat, self.q[i][0], self.q[i][1], self.q[i][2], self.q[i][3]
+            )
+            pat = is_op("expand_dims")(pat)
+            L.append(pat)
+
+        T = is_tuple(L)
+        self.pattern = is_op("concatenate")(T)
+
+    def check_quant_params(self, node_map):
+        """checking if dense and requant params are the same across patterns"""
+        r = self.batch
+        x1 = [node_map[self.d[0][i]][0].data.numpy().item() for i in range(4)]
+        x2 = [node_map[self.q[0][i]][0].data.numpy().item() for i in range(4)]
+        for i in range(1, r):
+            for j in range(4):
+                y1 = node_map[self.d[i][j]][0].data.numpy().item()
+                y2 = node_map[self.q[i][j]][0].data.numpy().item()
+                if x1[j] != y1 or x2[j] != y2:
+                    return False
+        return True
+
+    def callback(self, pre, post, node_map):
+        A = node_map[self.A][0]
+        B = node_map[self.B][0]
+
+        if not self.check_quant_params(node_map):
+            return post
+
+        [a0, a1, a2] = [0, 0, 0]  # Tensor A shape
+        [b0, b1, b2] = [0, 0, 0]  # Tensor B shape
+
+        if isinstance(A, relay.expr.Call) and isinstance(B, relay.expr.Call):
+            if A.checked_type is None or B.checked_type is None:
+                # The type inference pass needs to run before this pass
+                return post
+            if len(A.checked_type.shape) == 3 and len(B.checked_type.shape) == 3:
+                [a0, a1, a2] = A.checked_type.shape
+                [b0, b1, b2] = B.checked_type.shape
+
+        if isinstance(A, relay.Var) and isinstance(B, relay.Var):
+            if len(A.type_annotation.shape) == 3 and len(B.type_annotation.shape) == 3:
+                [a0, a1, a2] = A.type_annotation.shape
+                [b0, b1, b2] = B.type_annotation.shape
+
+        # Check that the batch size is the same as the expected tensor size
+        if (a0 != self.batch) or (b0 != self.batch):
+            return post
+
+        for i in range(self.batch):
+            # end=(x, pa1, pa2) attribute of strided_slice for Tensor A
+            pa1 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[1].value
+            pa2 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[2].value
+
+            # end=(x, pb1, pb2) attribute of strided_slice for Tensor B
+            pb1 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[1].value
+            pb2 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[2].value
+
+            if a1 != pa1 or a2 != pa2 or b1 != pb1 or b2 != pb2:
+                return post
+
+        d = [node_map[self.d[0][i]][0] for i in range(4)]
+        q = [node_map[self.q[0][i]][0] for i in range(4)]
+
+        out = relay.op.transpose(B, axes=[0, 2, 1])
+        out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], out_dtype="int32")
+        out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], out_dtype="int8")
+        return out
+
+
+def rewrite_qdistilbert(mod):
+    """Rewrite the Quantized Distilbert to reduce computational complexity."""
+    mod["main"] = rewrite(qdistilbert_rewrite(), mod["main"])
+    return mod
+
+
+class remove_empty_pad_callback(DFPatternCallback):
+    """
+    A callback to remove empty pad op from the below pattern:
+    Pattern:
+    %0 = cast(0f, dtype="float16");
+    %1 = nn.pad(%inp, %0, pad_width=[[0i64, 0i64], [0i64, 0i64]]);
+    nn.matmul(%1, %inp2, units=None)
+
+    """
+
+    def __init__(self):
+        super(remove_empty_pad_callback, self).__init__()
+        self.A = wildcard()
+        self.B = wildcard()
+        self.a = is_op("nn.pad")(self.A, wildcard()).has_attr({"pad_width": ((0, 0), (0, 0))})
+        self.pattern = is_op("nn.matmul")(self.a, self.B)
+
+    def callback(self, pre, post, node_map):
+        A = node_map[self.A][0]
+        B = node_map[self.B][0]
+        return relay.nn.matmul(A, B)
+
+
+def remove_empty_pad(mod):
+    """Remove the empty pad operator."""
+    mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
+    return mod
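
A minimal usage sketch for the two new entry points (not part of this commit; the driver function name is hypothetical, and it assumes an existing Relay IRModule `mod` whose main function contains the patterns above):

import tvm
from tvm import relay
from tvm.contrib.hexagon.transform import remove_empty_pad, rewrite_qdistilbert


def apply_hexagon_rewrites(mod: tvm.IRModule) -> tvm.IRModule:
    """Hypothetical driver: run type inference, then both pattern rewrites."""
    # Populate checked_type on intermediate calls; qdistilbert_rewrite.callback
    # returns the graph unchanged when shape information is unavailable.
    mod = relay.transform.InferType()(mod)
    # Roll the repeated slice/dense/requantize chains into one qnn.batch_matmul.
    mod = rewrite_qdistilbert(mod)
    # Drop nn.pad ops whose pad_width is all zeros before nn.matmul.
    mod = remove_empty_pad(mod)
    return mod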
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-wildcard-import, invalid-name
+
+"""
+Test hexagon relay transforms
+"""
+import tvm
+from tvm import relay
+from tvm.contrib.hexagon.transform import rewrite_qdistilbert, remove_empty_pad
+from tvm import testing
+
+
+def test_rewrite_qdistilbert():
+    """Test case for rewrite_qdistilbert"""
+    A = relay.var("A", shape=(12, 128, 64), dtype="int8")
+    B = relay.var("B", shape=(12, 64, 128), dtype="int8")
+
+    z = tvm.tir.IntImm("int64", 0)
+    s1 = tvm.tir.IntImm("int64", 1)
+    tx = tvm.tir.IntImm("int64", 128)
+    ty = tvm.tir.IntImm("int64", 64)
+    expand_dims = []
+    for i in range(12):
+        d1 = relay.const(13, dtype="int32")
+        d2 = relay.const(1, dtype="int32")
+        d3 = relay.const(0.0541715, dtype="float32")
+        d4 = relay.const(0.0489368, dtype="float32")
+
+        q1 = relay.const(0.00265098, dtype="float32")
+        q2 = relay.const(0, dtype="int32")
+        q3 = relay.const(0.728874, dtype="float32")
+        q4 = relay.const(-14, dtype="int32")
+
+        x = tvm.tir.IntImm("int64", i)
+        y = tvm.tir.IntImm("int64", i + 1)
+
+        SA = relay.op.strided_slice(
+            A, begin=[x, z, z], end=[y, tx, ty], strides=[s1, s1, s1], axes=None
+        )
+        RA = relay.op.reshape(SA, [128, 64])
+        SB = relay.op.strided_slice(
+            B, begin=[x, z, z], end=[y, ty, tx], strides=[s1, s1, s1], axes=None
+        )
+        RB = relay.op.reshape(SB, [64, 128])
+        TB = relay.op.transpose(RB, [1, 0])
+        dense = relay.qnn.op.dense(RA, TB, d1, d2, d3, d4, units=None, out_dtype="int32")
+        requantize = relay.qnn.op.requantize(dense, q1, q2, q3, q4)
+        expand_dims.append(relay.op.expand_dims(requantize, axis=0))
+
+    t = relay.expr.Tuple(expand_dims)
+    graph = relay.op.concatenate(t, axis=0)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+    mod = rewrite_qdistilbert(mod)
+
+    d1 = relay.const(13, dtype="int32")
+    d2 = relay.const(1, dtype="int32")
+    d3 = relay.const(0.0541715, dtype="float32")
+    d4 = relay.const(0.0489368, dtype="float32")
+
+    q1 = relay.const(0.00265098, dtype="float32")
+    q2 = relay.const(0, dtype="int32")
+    q3 = relay.const(0.728874, dtype="float32")
+    q4 = relay.const(-14, dtype="int32")
+
+    ref = relay.op.transpose(B, [0, 2, 1])
+    ref = relay.qnn.op.batch_matmul(A, ref, d1, d2, d3, d4, out_dtype="int32")
+    ref = relay.qnn.op.requantize(ref, q1, q2, q3, q4, out_dtype="int8")
+    ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
+    ref_mod = tvm.IRModule.from_expr(ref_func)
+
+    assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])
+
+    # If the pattern does not match, should return the original.
+    func = relay.expr.Tuple(expand_dims)  # omitting concatenate
+    mod = tvm.IRModule.from_expr(func)
+    out_mod = rewrite_qdistilbert(mod)  # returns the original mod, not ref_mod
+
+    assert tvm.ir.structural_equal(mod["main"], out_mod["main"])
+
+
+def test_remove_empty_pad():
+    """Test case for remove_empty_pad"""
+    A = relay.var("A", shape=(32, 32), dtype="float16")
+    B = relay.var("B", shape=(32, 32), dtype="float16")
+
+    p0 = relay.cast(relay.const(0, dtype="float32"), dtype="float16")
+    p1 = relay.nn.pad(A, pad_value=p0, pad_width=((0, 0), (0, 0)))
+    graph = relay.nn.matmul(p1, B)
+
+    func = relay.Function(relay.analysis.free_vars(graph), graph)
+    mod = tvm.IRModule.from_expr(func)
+
+    mod = remove_empty_pad(mod)
+
+    ref = relay.nn.matmul(A, B)
+    ref_func = relay.Function(relay.analysis.free_vars(ref), ref)
+    ref_mod = tvm.IRModule.from_expr(ref_func)
+
+    assert tvm.ir.structural_equal(mod["main"], ref_mod["main"])
+
+
+if __name__ == "__main__":
+    testing.main()
