Skip to content

Commit e005f85

Browse files
[Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format that supports quantization (#16651)
* support conv2d when data_format is NHWC * modify the annotation * Do not convert input data when processing quantization conv_2d nodes * Fix code formatting issues * fixed error code format * update dequantize and quantize * fixed bug when model is fp32 model * update dequantize and quantize * update for paddle quantize model when format is NCHW
1 parent 6ca2341 commit e005f85

File tree

1 file changed

+74
-9
lines changed

1 file changed

+74
-9
lines changed

python/tvm/relay/frontend/paddlepaddle.py

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .. import function as _function
3232
from .. import ty as _ty
3333
from .. import op as _op
34+
from .. import qnn as _qnn
3435
from .common import (
3536
autopad,
3637
fold_constant,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
314315
strides = op.attr("strides")
315316

316317
kernel = g.get_node(op.input("Filter")[0])
317-
kernel_layout = "OIHW"
318318
input_x = g.get_node(op.input("Input")[0])
319319
data_layout = op.attr("data_format")
320+
kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
320321
out_channels, _, k_h, k_w = infer_shape(kernel)
321322
if padding_algorithm == "VALID":
322323
paddings = [0, 0]
@@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
336337
msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
337338
raise tvm.error.OpAttributeInvalid(msg)
338339

339-
if data_layout == "NHWC":
340-
kernel_layout = "HWIO"
341-
# PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
340+
is_quantized = op.has_attr("quantization_type")
341+
# PaddlePaddle weight layout is "OIHW"; TVM needs "HWIO" when the op data_format is "NHWC".
342+
# There are two situations when converting the data format of weights:
343+
# 1 Conv_2d is not a quantized op; its weight information is the weights themselves.
344+
# We directly convert the weight information when processing conv_2d.
345+
# 2 Conv_2d is a quantized op, and its weight information is the output of
346+
# the quantize_linear operator. Therefore, the weight information needs to be
347+
# transformed when processing the quantize_linear operator.
348+
if (not is_quantized) and (data_layout == "NHWC"):
342349
kernel_data = g.get_params(op.input("Filter")[0])
343350
kernel_data = kernel_data.asnumpy()
344351
kernel_data = kernel_data.transpose((2, 3, 1, 0))
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
16261633
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))
16271634

16281635
# handle with special case
1629-
# while kernel size less than input size
1636+
# when the kernel size is larger than the input size,
16301637
# shrink kernel size to input size
16311638
if (
16321639
not isinstance(in_h, _op.Expr)
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
18121819
g.add_node(op.output("Out")[0], out)
18131820

18141821

1822+
def convert_dequantize_linear(g, op, block):
    """Operator converter for dequantize_linear.

    Looks up the node named by input "X", reads the "Scale" and "ZeroPoint"
    parameters, builds a ``_qnn.op.dequantize`` expression from them and
    registers it under the op's "Y" output name. ``block`` is not used here.
    """

    quantized = g.get_node(op.input("X")[0])

    # The stored Paddle scale is divided by 127 before being handed to TVM
    # (paddle_scale = tvm_scale * 127).
    scale = g.get_params(op.input("Scale")[0]).asnumpy() / 127.0
    zero_point = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

    axis = op.attr("quant_axis")
    rank = len(infer_shape(quantized))
    # A quant_axis of -1, or an input with fewer than two dimensions,
    # falls back to axis 0.
    if axis == -1 or rank < 2:
        axis = 0

    dequantized = _qnn.op.dequantize(
        data=quantized,
        input_scale=_op.const(scale, "float32"),
        input_zero_point=_op.const(zero_point, "int32"),
        axis=axis,
    )
    g.add_node(op.output("Y")[0], dequantized)
1848+
1849+
1850+
def convert_quantize_linear(g, op, block):
    """Operator converter for quantize_linear.

    Looks up the node named by input "X", reads the "Scale" and "ZeroPoint"
    parameters, builds a ``_qnn.op.quantize`` expression from them and
    registers it under the op's "Y" output name. ``block`` is not used here.
    """

    source = g.get_node(op.input("X")[0])

    # The stored Paddle scale is divided by 127 before being handed to TVM
    # (paddle_scale = tvm_scale * 127).
    scale = g.get_params(op.input("Scale")[0]).asnumpy() / 127.0
    zero_point = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

    # A quant_axis of -1 is remapped to axis 0; any other value is used as is.
    axis = op.attr("quant_axis")
    axis = 0 if axis == -1 else axis

    quantized = _qnn.op.quantize(
        data=source,
        output_scale=_op.const(scale, "float32"),
        output_zero_point=_op.const(zero_point, "int32"),
        axis=axis,
    )
    g.add_node(op.output("Y")[0], quantized)
1873+
1874+
18151875
def convert_rnn(g, op, block):
18161876
"""Operator converter for rnn."""
18171877

@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
23862446
def convert_softmax(g, op, block):
23872447
"""Operator converter for softmax."""
23882448

2449+
x = g.get_node(op.input("X")[0])
23892450
axis = op.attr("axis")
23902451
input_shape = block.var(op.input("X")[0]).shape
23912452
if axis < 0:
23922453
axis = len(input_shape) + axis
2393-
x = g.get_node(op.input("X")[0])
23942454
m = _op.max(x, axis, keepdims=True)
23952455
e = _op.exp(x - m)
23962456
out = e / _op.sum(e, axis, keepdims=True)
@@ -2905,6 +2965,9 @@ def convert_where_index(g, op, block):
29052965
"unstack": convert_unstack,
29062966
"where": convert_where,
29072967
"where_index": convert_where_index,
2968+
# Quantized
2969+
"dequantize_linear": convert_dequantize_linear,
2970+
"quantize_linear": convert_quantize_linear,
29082971
}
29092972

29102973

@@ -2938,7 +3001,7 @@ def get_params(self, name=None):
29383001

29393002
if name is None:
29403003
return self.params
2941-
assert name in self.params
3004+
assert name in self.params, f"The name({name}) is not in params"
29423005
return self.params[name]
29433006

29443007
def extract_parameters(self, program, scope=None):
@@ -2947,9 +3010,12 @@ def extract_parameters(self, program, scope=None):
29473010
self.params = {}
29483011
variables = program.global_block().vars
29493012
for name in variables:
2950-
var = program.global_block().var(name)
29513013
if name.endswith("feed") or name.endswith("fetch"):
29523014
continue
3015+
# This judgment will cause the PaddleInference model
3016+
# exported by PaddleSlim to skip some operators
3017+
# that need to be read in NHWC format.
3018+
var = program.global_block().var(name)
29533019
if not var.persistable:
29543020
continue
29553021
if isinstance(scope, dict):
@@ -3018,7 +3084,6 @@ def from_program(self, program, shape_dict, scope):
30183084
for op in block.ops:
30193085
if op.type == "fetch":
30203086
output_names.append(op.input("X")[0])
3021-
30223087
outputs = [self.nodes[name] for name in output_names]
30233088
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
30243089

0 commit comments

Comments
 (0)