3131from .. import function as _function
3232from .. import ty as _ty
3333from .. import op as _op
34+ from .. import qnn as _qnn
3435from .common import (
3536 autopad ,
3637 fold_constant ,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
314315 strides = op .attr ("strides" )
315316
316317 kernel = g .get_node (op .input ("Filter" )[0 ])
317- kernel_layout = "OIHW"
318318 input_x = g .get_node (op .input ("Input" )[0 ])
319319 data_layout = op .attr ("data_format" )
320+ kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
320321 out_channels , _ , k_h , k_w = infer_shape (kernel )
321322 if padding_algorithm == "VALID" :
322323 paddings = [0 , 0 ]
@@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
336337 msg = f'Value { padding_algorithm } in attribute "padding" of operator Conv is not "valid."'
337338 raise tvm .error .OpAttributeInvalid (msg )
338339
339- if data_layout == "NHWC" :
340- kernel_layout = "HWIO"
341- # PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
340+ is_quantized = op .has_attr ("quantization_type" )
341+ # PaddlePaddle weight layout is "OIHW"; TVM needs "HWIO" when the op data_format is "NHWC".
342+ # There are two situations when converting the data format of weights:
343+ # 1. Conv2d is not a quantized op: its weight information is the weights themselves,
344+ # so we convert the weight layout directly while processing conv2d.
345+ # 2. Conv2d is a quantized op: its weight information is the output of the
346+ # quantize_linear operator, so the weight layout must instead be
347+ # transformed while processing the quantize_linear operator.
348+ if (not is_quantized ) and (data_layout == "NHWC" ):
342349 kernel_data = g .get_params (op .input ("Filter" )[0 ])
343350 kernel_data = kernel_data .asnumpy ()
344351 kernel_data = kernel_data .transpose ((2 , 3 , 1 , 0 ))
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
16261633 raise tvm .error .OpAttributeInvalid (msg .format (padding_algorithm ))
16271634
16281635 # handle with special case
1629- # while kernel size less than input size
1636+ # while kernel size more than input size
16301637 # shrink kernel size to input size
16311638 if (
16321639 not isinstance (in_h , _op .Expr )
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
18121819 g .add_node (op .output ("Out" )[0 ], out )
18131820
18141821
def convert_dequantize_linear(g, op, block):
    """Operator converter for dequantize_linear."""

    x = g.get_node(op.input("X")[0])

    # PaddlePaddle stores the scale as tvm_scale * 127 (int8 range),
    # so divide it back out before handing it to qnn.
    scale = g.get_params(op.input("Scale")[0]).asnumpy() / 127.0
    zero_point = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

    # quant_axis == -1 marks per-tensor quantization, and inputs of rank < 2
    # have no meaningful channel axis; both cases fall back to axis 0.
    axis = op.attr("quant_axis")
    if axis == -1 or len(infer_shape(x)) < 2:
        axis = 0

    out = _qnn.op.dequantize(
        data=x,
        input_scale=_op.const(scale, "float32"),
        input_zero_point=_op.const(zero_point, "int32"),
        axis=axis,
    )
    g.add_node(op.output("Y")[0], out)
1849+
def convert_quantize_linear(g, op, block):
    """Operator converter for quantize_linear."""

    data_node_name = op.input("X")[0]
    data_node = g.get_node(data_node_name)

    # PaddlePaddle stores the scale as tvm_scale * 127 (int8 range),
    # so divide it back out before handing it to qnn.
    paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
    tvm_quantize_scale = paddle_quantize_scale / 127.0

    tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
    tvm_quantize_axis = op.attr("quant_axis")

    # quant_axis == -1 marks per-tensor quantization; qnn expects a
    # concrete axis, so fall back to axis 0.
    if tvm_quantize_axis == -1:
        tvm_quantize_axis = 0

    out = _qnn.op.quantize(
        data=data_node,
        output_scale=_op.const(tvm_quantize_scale, "float32"),
        output_zero_point=_op.const(tvm_quantize_zp, "int32"),
        axis=tvm_quantize_axis,
    )
    g.add_node(op.output("Y")[0], out)
1873+
1874+
18151875def convert_rnn (g , op , block ):
18161876 """Operator converter for rnn."""
18171877
@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
23862446def convert_softmax (g , op , block ):
23872447 """Operator converter for softmax."""
23882448
2449+ x = g .get_node (op .input ("X" )[0 ])
23892450 axis = op .attr ("axis" )
23902451 input_shape = block .var (op .input ("X" )[0 ]).shape
23912452 if axis < 0 :
23922453 axis = len (input_shape ) + axis
2393- x = g .get_node (op .input ("X" )[0 ])
23942454 m = _op .max (x , axis , keepdims = True )
23952455 e = _op .exp (x - m )
23962456 out = e / _op .sum (e , axis , keepdims = True )
@@ -2905,6 +2965,9 @@ def convert_where_index(g, op, block):
29052965 "unstack" : convert_unstack ,
29062966 "where" : convert_where ,
29072967 "where_index" : convert_where_index ,
2968+ # Quantized
2969+ "dequantize_linear" : convert_dequantize_linear ,
2970+ "quantize_linear" : convert_quantize_linear ,
29082971}
29092972
29102973
@@ -2938,7 +3001,7 @@ def get_params(self, name=None):
29383001
29393002 if name is None :
29403003 return self .params
2941- assert name in self .params
3004+ assert name in self .params , f"The name( { name } ) is not in params"
29423005 return self .params [name ]
29433006
29443007 def extract_parameters (self , program , scope = None ):
@@ -2947,9 +3010,12 @@ def extract_parameters(self, program, scope=None):
29473010 self .params = {}
29483011 variables = program .global_block ().vars
29493012 for name in variables :
2950- var = program .global_block ().var (name )
29513013 if name .endswith ("feed" ) or name .endswith ("fetch" ):
29523014 continue
3015+ # This judgment will cause the PaddleInference model
3016+ # exported by PaddleSlim to skip some operators
3017+ # that need to be read in NHWC format.
3018+ var = program .global_block ().var (name )
29533019 if not var .persistable :
29543020 continue
29553021 if isinstance (scope , dict ):
@@ -3018,7 +3084,6 @@ def from_program(self, program, shape_dict, scope):
30183084 for op in block .ops :
30193085 if op .type == "fetch" :
30203086 output_names .append (op .input ("X" )[0 ])
3021-
30223087 outputs = [self .nodes [name ] for name in output_names ]
30233088 outputs = outputs [0 ] if len (outputs ) == 1 else _expr .Tuple (outputs )
30243089
0 commit comments