20 | 20 | import functools as ft |
21 | 21 | |
22 | 22 | import tvm |
| 23 | +from tvm import relay |
| 24 | +from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard |
| 25 | +from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple |
23 | 26 | from ..._ffi.registry import register_func |
24 | 27 | |
25 | 28 | ### VTCM |
@@ -148,3 +151,163 @@ def transform(func, mod, ctx): |
148 | 151 | |
149 | 152 | def ir_lower_vtcm_pass(): |
150 | 153 | return [(3, ir_lower_vtcm())] |
| 154 | + |
| 155 | + |
| 156 | +class qdistilbert_rewrite(DFPatternCallback): |
| 157 | +    """ |
| 158 | +    A callback to replace the below pattern with a single qnn.batch_matmul and qnn.requantize: |
| 159 | +    Pattern: |
| 160 | +    %35 = strided_slice(%34, begin=[0, 0, 0], end=[1, 128, 64], strides=[1, 1, 1], axes=None); |
| 161 | +    %44 = reshape(%35, newshape=[-1, 64]); |
| 162 | +    <snip> |
| 163 | +    %42 = strided_slice(%41, begin=[0, 0, 0], end=[1, 64, 128], strides=[1, 1, 1], axes=None); |
| 164 | +    %43 = reshape(%42, newshape=[64, 128]); |
| 165 | +    %45 = transpose(%43, axes=[1, 0]); |
| 166 | +    <snip> |
| 167 | +    %46 = qnn.dense(%44, %45, 13, 1, 0.0541715f, 0.0489368f, units=None, out_dtype="int32"); |
| 168 | +    %47 = qnn.requantize(%46, 0.00265098f, 0, 0.728874f, -14, axis=1, out_dtype="int8"); |
| 169 | +    <snip> |
| 170 | +    %125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */; |
| 171 | +    <The above pattern repeats 12 times, once for each element of the batch> |
| 172 | + |
| 173 | +    %137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136); |
| 174 | +    %138 = concatenate(%137); |
| 175 | + |
| 176 | +    """ |
| 177 | + |
| 178 | +    def __init__(self): |
| 179 | +        super(qdistilbert_rewrite, self).__init__() |
| 180 | +        self.A = wildcard()  # Tensor A |
| 181 | +        self.B = wildcard()  # Tensor B |
| 182 | +        self.batch = 12  # Number of times the pattern repeats (the batch size) |
| 183 | + |
| 184 | +        self.d = []  # List of dense quantization parameters |
| 185 | +        self.q = []  # List of requantize parameters |
| 186 | +        L = []  # List of patterns, one per batch element |
| 187 | + |
| 188 | +        z = tvm.tir.IntImm("int64", 0) |
| 189 | +        s1 = tvm.tir.IntImm("int64", 1) |
| 190 | + |
| 191 | +        for i in range(self.batch): |
| 192 | +            x = tvm.tir.IntImm("int64", i) |
| 193 | + |
| 194 | +            self.d.append([is_constant(), is_constant(), is_constant(), is_constant()]) |
| 195 | +            self.q.append([is_constant(), is_constant(), is_constant(), is_constant()]) |
| 196 | + |
| 197 | +            pat_a = is_op("strided_slice")(self.A).has_attr( |
| 198 | +                {"begin": [x, z, z], "strides": [s1, s1, s1]} |
| 199 | +            ) |
| 200 | +            pat_a = is_op("reshape")(pat_a) |
| 201 | + |
| 202 | +            pat_b = is_op("strided_slice")(self.B).has_attr( |
| 203 | +                {"begin": [x, z, z], "strides": [s1, s1, s1]} |
| 204 | +            ) |
| 205 | +            pat_b = is_op("reshape")(pat_b) |
| 206 | +            pat_b = is_op("transpose")(pat_b) |
| 207 | + |
| 208 | +            pat = is_op("qnn.dense")( |
| 209 | +                pat_a, pat_b, self.d[i][0], self.d[i][1], self.d[i][2], self.d[i][3] |
| 210 | +            ) |
| 211 | +            pat = is_op("qnn.requantize")( |
| 212 | +                pat, self.q[i][0], self.q[i][1], self.q[i][2], self.q[i][3] |
| 213 | +            ) |
| 214 | +            pat = is_op("expand_dims")(pat) |
| 215 | +            L.append(pat) |
| 216 | + |
| 217 | +        T = is_tuple(L) |
| 218 | +        self.pattern = is_op("concatenate")(T) |
| 219 | + |
| 220 | +    def check_quant_params(self, node_map): |
| 221 | +        """Check that the dense and requantize parameters are identical across all repeats""" |
| 222 | +        r = self.batch |
| 223 | +        x1 = [node_map[self.d[0][i]][0].data.numpy().item() for i in range(4)] |
| 224 | +        x2 = [node_map[self.q[0][i]][0].data.numpy().item() for i in range(4)] |
| 225 | +        for i in range(1, r): |
| 226 | +            for j in range(4): |
| 227 | +                y1 = node_map[self.d[i][j]][0].data.numpy().item() |
| 228 | +                y2 = node_map[self.q[i][j]][0].data.numpy().item() |
| 229 | +                if x1[j] != y1 or x2[j] != y2: |
| 230 | +                    return False |
| 231 | +        return True |
| 232 | + |
| 233 | +    def callback(self, pre, post, node_map): |
| 234 | +        A = node_map[self.A][0] |
| 235 | +        B = node_map[self.B][0] |
| 236 | + |
| 237 | +        if not self.check_quant_params(node_map): |
| 238 | +            return post |
| 239 | + |
| 240 | +        [a0, a1, a2] = [0, 0, 0]  # Tensor A shape |
| 241 | +        [b0, b1, b2] = [0, 0, 0]  # Tensor B shape |
| 242 | + |
| 243 | +        if isinstance(A, relay.expr.Call) and isinstance(B, relay.expr.Call): |
| 244 | +            if A._checked_type_ is None or B._checked_type_ is None: |
| 245 | +                # The type inference pass must run before this pass; the |
| 246 | +                # checked_type property would raise if the type is unset. |
| 247 | +                return post |
| 248 | +            if len(A.checked_type.shape) == 3 and len(B.checked_type.shape) == 3: |
| 249 | +                [a0, a1, a2] = A.checked_type.shape |
| 250 | +                [b0, b1, b2] = B.checked_type.shape |
| 251 | + |
| 252 | +        if isinstance(A, relay.Var) and isinstance(B, relay.Var): |
| 253 | +            if len(A.type_annotation.shape) == 3 and len(B.type_annotation.shape) == 3: |
| 254 | +                [a0, a1, a2] = A.type_annotation.shape |
| 255 | +                [b0, b1, b2] = B.type_annotation.shape |
| 256 | + |
| 257 | +        # Check that the outermost dimension of both tensors matches the expected batch size |
| 258 | +        if (a0 != self.batch) or (b0 != self.batch): |
| 259 | +            return post |
| 260 | + |
| 261 | +        for i in range(self.batch): |
| 262 | +            # end=(x, pa1, pa2) attribute of strided_slice for Tensor A |
| 263 | +            pa1 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[1].value |
| 264 | +            pa2 = pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[2].value |
| 265 | + |
| 266 | +            # end=(x, pb1, pb2) attribute of strided_slice for Tensor B |
| 267 | +            pb1 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[1].value |
| 268 | +            pb2 = pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[2].value |
| 269 | + |
| 270 | +            if a1 != pa1 or a2 != pa2 or b1 != pb1 or b2 != pb2: |
| 271 | +                return post |
| 272 | + |
| 273 | +        d = [node_map[self.d[0][i]][0] for i in range(4)] |
| 274 | +        q = [node_map[self.q[0][i]][0] for i in range(4)] |
| 275 | + |
| 276 | +        out = relay.op.transpose(B, axes=[0, 2, 1]) |
| 277 | +        out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], out_dtype="int32") |
| 278 | +        out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], out_dtype="int8") |
| 279 | +        return out |
| 279 | + |
| 280 | + |
| 281 | +def rewrite_qdistilbert(mod): |
| 282 | +    """Rewrite the quantized DistilBERT model to reduce computational complexity.""" |
| 283 | +    mod["main"] = rewrite(qdistilbert_rewrite(), mod["main"]) |
| 284 | +    return mod |
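
For reference, a minimal driver sketch for this entry point (illustrative only, not part of the patch). It assumes the file lands at `python/tvm/contrib/hexagon/transform.py`, so the import path `tvm.contrib.hexagon.transform` is an assumption, and that `mod` is a `tvm.IRModule` containing the quantized DistilBERT pattern from the docstring. Since `callback` inspects `checked_type`, type inference has to run first:

```python
import tvm
from tvm import relay

# Assumed import path; this patch adds rewrite_qdistilbert to the hexagon transforms.
from tvm.contrib.hexagon.transform import rewrite_qdistilbert


def run_qdistilbert_rewrite(mod: tvm.IRModule) -> tvm.IRModule:
    # Populate checked_type on every node; the callback returns the matched
    # expression unchanged when type information is missing.
    mod = relay.transform.InferType()(mod)
    # Collapse the 12 per-batch qnn.dense/qnn.requantize chains into a single
    # qnn.batch_matmul followed by one qnn.requantize.
    return rewrite_qdistilbert(mod)
```
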
| 285 | + |
| 286 | + |
| 287 | +class remove_empty_pad_callback(DFPatternCallback): |
| 288 | +    """ |
| 289 | +    A callback to remove an empty (zero-width) pad op from the below pattern: |
| 290 | +    Pattern: |
| 291 | +    %0 = cast(0f, dtype="float16"); |
| 292 | +    %1 = nn.pad(%inp, %0, pad_width=[[0i64, 0i64], [0i64, 0i64]]); |
| 293 | +    nn.matmul(%1, %inp2, units=None) |
| 294 | + |
| 295 | +    """ |
| 296 | + |
| 297 | +    def __init__(self): |
| 298 | +        super(remove_empty_pad_callback, self).__init__() |
| 299 | +        self.A = wildcard() |
| 300 | +        self.B = wildcard() |
| 301 | +        self.a = is_op("nn.pad")(self.A, wildcard()).has_attr({"pad_width": ((0, 0), (0, 0))}) |
| 302 | +        self.pattern = is_op("nn.matmul")(self.a, self.B) |
| 303 | + |
| 304 | +    def callback(self, pre, post, node_map): |
| 305 | +        A = node_map[self.A][0] |
| 306 | +        B = node_map[self.B][0] |
| 307 | +        return relay.nn.matmul(A, B) |
| 308 | + |
| 309 | + |
| 310 | +def remove_empty_pad(mod): |
| 311 | +    """Remove the empty pad operator.""" |
| 312 | +    mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"]) |
| 313 | +    return mod |
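
A small end-to-end sketch for the pad removal, again illustrative: the shapes and variable names are made up, and the import path is the same assumption as above. It builds a function matching the docstring pattern and applies the pass, after which `nn.matmul` should consume the input directly:

```python
import tvm
from tvm import relay

from tvm.contrib.hexagon.transform import remove_empty_pad  # assumed path

x = relay.var("x", shape=(4, 8), dtype="float16")
y = relay.var("y", shape=(8, 4), dtype="float16")
zero = relay.cast(relay.const(0.0), "float16")
# Zero-width padding on both axes: a no-op that only obscures the graph.
padded = relay.nn.pad(x, pad_width=((0, 0), (0, 0)), pad_value=zero)
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.matmul(padded, y)))
mod = remove_empty_pad(mod)  # main becomes: nn.matmul(x, y)
```
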