@@ -56,13 +56,25 @@ class ConvPackedParam(QNNParam):
5656 """
5757
def __init__(
    self,
    weight_np,
    bias,
    scale,
    zero_point,
    param_name,
    stride,
    padding,
    dilation,
    groups,
    output_padding,
):
    """Quantized conv parameter bundle.

    Extends QNNParam with the layer hyperparameters carried by the Torch
    packed-parameter object: stride, padding, dilation, groups, and
    output_padding (the last one is meaningful only for conv_transpose2d).
    """
    super().__init__(weight_np, bias, scale, zero_point, param_name)
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    # Only consumed by the quantized::conv_transpose2d converter.
    self.output_padding = output_padding
6678
6779
6880def _get_quant_params (qweight ):
@@ -92,8 +104,18 @@ def make_conv_packed_param(param_name, qweight, bias, packed_params):
92104 padding = packed_params .padding ()
93105 dilation = packed_params .dilation ()
94106 groups = packed_params .groups ()
107+ output_padding = packed_params .output_padding ()
95108 return ConvPackedParam (
96- weight_np , bias , scale , zero_point , param_name , stride , padding , dilation , groups
109+ weight_np ,
110+ bias ,
111+ scale ,
112+ zero_point ,
113+ param_name ,
114+ stride ,
115+ padding ,
116+ dilation ,
117+ groups ,
118+ output_padding ,
97119 )
98120
99121
@@ -154,7 +176,13 @@ def add_quant_params_to_outputs(outputs, packed_param_map, quant_params):
154176 params = [qweight , qparam .scale , qparam .zero_point , qparam .bias_var ]
155177
156178 if isinstance (quant_params [packed_param_name ], ConvPackedParam ):
157- params += [qparam .stride , qparam .padding , qparam .dilation , qparam .groups ]
179+ params += [
180+ qparam .stride ,
181+ qparam .padding ,
182+ qparam .dilation ,
183+ qparam .groups ,
184+ qparam .output_padding ,
185+ ]
158186
159187 outputs [node_name ] = params
160188
@@ -192,6 +220,7 @@ def _get_quant_param_for_input(input_value):
192220 "quantized::mul_scalar" : (2 , 3 ),
193221 "quantized::add_scalar" : (2 , 3 ),
194222 "quantized::hardswish" : (1 , 2 ),
223+ "quantized::conv_transpose2d" : qconv_indices ,
195224 }
196225
197226 def dfs (current_node ):
@@ -362,6 +391,7 @@ def add_input_quant_params_to_op_inputs(graph):
362391 "quantized::relu6" : 1 ,
363392 "quantized::hardswish" : 1 ,
364393 "aten::hardsigmoid" : 1 ,
394+ "quantized::conv_transpose2d" : 1 ,
365395 }
366396
367397 need_input_quant_param = set (num_quantized_inputs .keys ())
@@ -924,6 +954,65 @@ def _impl(inputs, _):
924954 return _impl
925955
926956
def _quantized_conv_transpose2d(with_relu=False):
    """Build a converter for quantized::conv_transpose2d (Torch 1.7 or newer).

    Reference semantics: aten/src/ATen/native/quantized/cpu/qconv.cpp.
    Returns an ``_impl(inputs, _)`` callable suitable for ``convert_map``.
    """

    def _impl(inputs, _):
        # inputs[1] is the unpacked ConvPackedParam list:
        # [weight, w_scale, w_zp, bias, stride, padding, dilation, groups, output_padding]
        packed = inputs[1]
        weight, weight_scale, weight_zero_point, bias = packed[:4]
        strides, padding, dilation, groups, output_padding = packed[4:9]

        output_scale = _expr.const(inputs[2])
        output_zero_point = _expr.const(inputs[3])

        # The trailing two inputs are appended by
        # add_input_quant_params_to_op_inputs; in Torch they are read from the
        # QTensor data structure at runtime instead.
        assert len(inputs) == 6, "Input quant params not found in op inputs"
        input_scale = _expr.const(inputs[4])
        input_zero_point = _expr.const(inputs[5])

        weight_shape = list(infer_shape(weight))
        # Swap the I and O dims so the shape matches what relay expects
        # for an OIHW kernel layout.
        weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
        out_channels = weight_shape[0]
        kernel_size = (weight_shape[2], weight_shape[3])

        conv_out = relay.qnn.op.conv2d_transpose(
            inputs[0],
            weight,
            input_zero_point,
            weight_zero_point,
            input_scale,
            weight_scale,
            kernel_size=kernel_size,
            dilation=dilation,
            strides=strides,
            padding=padding,
            groups=groups,
            channels=out_channels,
            output_padding=output_padding,
            out_dtype="int32",
            kernel_layout="OIHW",
        )

        return _do_bias_and_requantize(
            conv_out, bias, input_scale, weight_scale, output_scale, output_zero_point, with_relu
        )

    return _impl
1014+
1015+
9271016convert_map = {
9281017 "aten::quantize_per_tensor" : _quantize_per_tensor (),
9291018 "quantized::conv2d_relu" : _quantized_conv2d (with_relu = True ),
@@ -941,4 +1030,5 @@ def _impl(inputs, _):
9411030 "quantized::relu6" : _relu6 (),
9421031 "quantized::linear_dynamic" : _linear_dynamic (),
9431032 "quantized::hardswish" : _hswish (),
1033+ "quantized::conv_transpose2d" : _quantized_conv_transpose2d (),
9441034}
0 commit comments