@@ -56,13 +56,24 @@ def _pack_batch_channel(data, dshape, bfactor, cfactor):
5656 return data
5757
5858
59- def _unpack_batch_channel (data , old_shape ):
59+ def _unpack_batch_channel (data , old_shape , unpack_transpose = False ):
6060 """Unpack the data channel dimension."""
61- data = op .transpose (data , axes = (0 , 4 , 1 , 5 , 2 , 3 ))
61+ if unpack_transpose :
62+ data = op .transpose (data , axes = (0 , 4 , 1 , 5 , 2 , 3 ))
6263 data = op .reshape (data , newshape = old_shape )
6364 return data
6465
6566
67+ def _channel_const_match (channel_length , cfactor_out ):
68+ """Round the chanel const variant if the value not divisible by cfactor_out"""
69+ diff = int (channel_length ) % cfactor_out
70+ if diff != 0 :
71+ diff = cfactor_out - diff
72+ channel_length = channel_length + diff
73+
74+ return diff , channel_length
75+
76+
6677def _const_shape_match (data , dshape , cfactor_out ):
6778 """Pad the constant if the shape[0] not divisible by cfactor_out."""
6879 assert len (dshape ) == 3
@@ -299,6 +310,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
299310 self .upsampling = op .op .get ("nn.upsampling" )
300311 self .reshape = op .op .get ("reshape" )
301312 self .number_of_conv2d = 0
313+ self .unpack_transpose = True
302314 super ().__init__ ()
303315
304316 def visit_call (self , call ):
@@ -319,7 +331,7 @@ def visit_call(self, call):
319331 self .start_pack = False
320332 data = args [0 ]
321333 data_shape = _get_tensor_shape (call .args [0 ])
322- return _unpack_batch_channel (data , data_shape )
334+ return _unpack_batch_channel (data , data_shape , self . unpack_transpose )
323335 if self .start_pack :
324336 # Operator cases
325337 if call .op == self .conv2d and odtype == "int32" :
@@ -429,12 +441,12 @@ def visit_call(self, call):
429441 if len (pad_width ) == 6 :
430442 pass
431443 elif len (pad_width ) == 4 :
432- (data ,) = args
444+ (data , pad_value ) = args
433445 new_pad_width = []
434446 new_pad_width .extend (pad_width )
435447 for _ in range (2 ):
436448 new_pad_width .append ([0 , 0 ])
437- return op .nn .pad (data , pad_value = call . attrs . pad_value , pad_width = new_pad_width )
449+ return op .nn .pad (data , pad_value = pad_value , pad_width = new_pad_width )
438450 elif call .op == self .upsampling :
439451 (data ,) = args
440452 scale_h = call .attrs .scale_h
@@ -445,8 +457,17 @@ def visit_call(self, call):
445457 return op .nn .upsampling (data , scale_h , scale_w , data_layout , method , align_corners )
446458 elif call .op == self .reshape and len (input_types [0 ].shape ) == 4 :
447459 (data ,) = args
460+ self .unpack_transpose = False
448461 data = op .transpose (data , axes = (0 , 4 , 1 , 5 , 2 , 3 ))
449- return op .reshape (data , [int (x ) for x in input_types [0 ].shape ])
462+ new_shape = [int (x ) for x in input_types [0 ].shape ]
463+ # Check if the reshape match with such shape after pad
464+ pad , new_shape [1 ] = _channel_const_match (new_shape [1 ], self .cfactor )
465+ data = op .reshape (data , new_shape )
466+ # remove pad data
467+ if pad != 0 :
468+ new_pad_width = [[0 , 0 ], [0 , - pad ], [0 , 0 ], [0 , 0 ]]
469+ data = op .nn .pad (data , pad_width = new_pad_width )
470+ return data
450471
451472 return relay .Call (self .visit (call .op ), args , call .attrs )
452473
0 commit comments