Commit fe60bf8

[VTA] Make vta graph_pack compatible with latest TVM, and bring back
object detection tutorials.
1 parent 1abd248 commit fe60bf8

File tree

2 files changed: +351 -6 lines changed

vta/python/vta/top/graphpack.py

Lines changed: 27 additions & 6 deletions
@@ -56,13 +56,24 @@ def _pack_batch_channel(data, dshape, bfactor, cfactor):
     return data
 
 
-def _unpack_batch_channel(data, old_shape):
+def _unpack_batch_channel(data, old_shape, unpack_transpose=False):
     """Unpack the data channel dimension."""
-    data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
+    if unpack_transpose:
+        data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
     data = op.reshape(data, newshape=old_shape)
     return data
 
 
+def _channel_const_match(channel_length, cfactor_out):
+    """Round the channel constant up if the value is not divisible by cfactor_out."""
+    diff = int(channel_length) % cfactor_out
+    if diff != 0:
+        diff = cfactor_out - diff
+        channel_length = channel_length + diff
+
+    return diff, channel_length
+
+
 def _const_shape_match(data, dshape, cfactor_out):
     """Pad the constant if the shape[0] is not divisible by cfactor_out."""
     assert len(dshape) == 3
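
Note: the new helper returns both the padding amount and the rounded-up length. A quick standalone check of the rounding rule (illustrative values, not from the commit), with the helper copied from the hunk above:

# Copied from the diff above so this snippet runs on its own.
def _channel_const_match(channel_length, cfactor_out):
    diff = int(channel_length) % cfactor_out
    if diff != 0:
        diff = cfactor_out - diff
        channel_length = channel_length + diff
    return diff, channel_length

print(_channel_const_match(255, 16))  # (1, 256): one padded channel reaches a multiple of 16
print(_channel_const_match(256, 16))  # (0, 256): already divisible, nothing to pad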
@@ -299,6 +310,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
         self.upsampling = op.op.get("nn.upsampling")
         self.reshape = op.op.get("reshape")
         self.number_of_conv2d = 0
+        self.unpack_transpose = True
         super().__init__()
 
     def visit_call(self, call):
@@ -319,7 +331,7 @@ def visit_call(self, call):
                 self.start_pack = False
                 data = args[0]
                 data_shape = _get_tensor_shape(call.args[0])
-                return _unpack_batch_channel(data, data_shape)
+                return _unpack_batch_channel(data, data_shape, self.unpack_transpose)
         if self.start_pack:
             # Operator cases
             if call.op == self.conv2d and odtype == "int32":
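
Note: unpack_transpose defaults to True, so this bitpack_end path keeps its old behavior; the reshape case further down sets it to False because that branch performs the transpose itself. A minimal sketch (hypothetical shapes) of what the flag controls:

from tvm import relay
from tvm.relay import op

def _unpack_batch_channel(data, old_shape, unpack_transpose=False):
    # With unpack_transpose set, restore NCHW order from the packed
    # NCHWnc layout before reshaping; otherwise the caller has already
    # transposed the data.
    if unpack_transpose:
        data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
    data = op.reshape(data, newshape=old_shape)
    return data

x = relay.var("x", shape=(1, 4, 8, 8, 1, 16), dtype="int8")  # packed layout
y = _unpack_batch_channel(x, (1, 64, 8, 8), unpack_transpose=True)
print(relay.Function([x], y))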
@@ -429,12 +441,12 @@ def visit_call(self, call):
                 if len(pad_width) == 6:
                     pass
                 elif len(pad_width) == 4:
-                    (data,) = args
+                    (data, pad_value) = args
                     new_pad_width = []
                     new_pad_width.extend(pad_width)
                     for _ in range(2):
                         new_pad_width.append([0, 0])
-                    return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width)
+                    return op.nn.pad(data, pad_value=pad_value, pad_width=new_pad_width)
             elif call.op == self.upsampling:
                 (data,) = args
                 scale_h = call.attrs.scale_h
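
Note: in recent TVM, nn.pad takes pad_value as an input expression rather than an attribute, which is presumably why this hunk unpacks it from args instead of reading call.attrs.pad_value. A minimal sketch (assumed shapes) of the pad_width extension, where two trailing [0, 0] pairs cover the packed sub-axes:

from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8, 1, 16), dtype="int8")  # packed NCHWnc
pad_value = relay.const(0, dtype="int8")
pad_width = [[0, 0], [0, 0], [1, 1], [1, 1]]        # original 4-D NCHW padding
new_pad_width = list(pad_width) + [[0, 0], [0, 0]]  # extended to 6-D
out = relay.nn.pad(data, pad_width=new_pad_width, pad_value=pad_value)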
@@ -445,8 +457,17 @@ def visit_call(self, call):
                 return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners)
             elif call.op == self.reshape and len(input_types[0].shape) == 4:
                 (data,) = args
+                self.unpack_transpose = False
                 data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
-                return op.reshape(data, [int(x) for x in input_types[0].shape])
+                new_shape = [int(x) for x in input_types[0].shape]
+                # Check whether the reshape target matches the shape after padding
+                pad, new_shape[1] = _channel_const_match(new_shape[1], self.cfactor)
+                data = op.reshape(data, new_shape)
+                # Remove the padded channels
+                if pad != 0:
+                    new_pad_width = [[0, 0], [0, -pad], [0, 0], [0, 0]]
+                    data = op.nn.pad(data, pad_width=new_pad_width)
+                return data
 
         return relay.Call(self.visit(call.op), args, call.attrs)
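
Note: taken together, the reshape branch now transposes first, reshapes to a channel count rounded up by _channel_const_match, and strips the rounding with a negative pad width. A minimal end-to-end sketch (hypothetical shapes, reusing the helper from the first hunk):

from tvm import relay
from tvm.relay import op

cfactor = 16
data = relay.var("data", shape=(1, 16, 8, 8, 1, 16), dtype="int8")  # packed NCHWnc
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))   # back toward plain NCHW order
new_shape = [1, 255, 8, 8]                            # shape the model asked for
pad, new_shape[1] = _channel_const_match(new_shape[1], cfactor)  # pad == 1, C becomes 256
data = op.reshape(data, new_shape)                    # reshape on the rounded shape
if pad != 0:
    # A negative pad width crops the padded channel off: (1, 256, 8, 8) -> (1, 255, 8, 8)
    data = op.nn.pad(data, pad_width=[[0, 0], [0, -pad], [0, 0], [0, 0]])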
