Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,10 +1826,12 @@ def convert_dequantize_linear(g, op, block):
data_node = g.get_node(data_node_name)

# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0
paddle_dequantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_dequantize_scale = paddle_dequantize_scale / 127.0
tvm_dequantize_scale = tvm_dequantize_scale.squeeze()

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
tvm_dequantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
tvm_dequantize_zp = tvm_dequantize_zp.squeeze()

tvm_quantize_axis = op.attr("quant_axis")
if tvm_quantize_axis == -1:
Expand All @@ -1840,8 +1842,8 @@ def convert_dequantize_linear(g, op, block):

out = _qnn.op.dequantize(
data=data_node,
input_scale=_op.const(tvm_quantize_scale, "float32"),
input_zero_point=_op.const(tvm_quantize_zp, "int32"),
input_scale=_expr.const(tvm_dequantize_scale, "float32"),
input_zero_point=_expr.const(tvm_dequantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)
Expand All @@ -1856,17 +1858,19 @@ def convert_quantize_linear(g, op, block):
# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0
tvm_quantize_scale = tvm_quantize_scale.squeeze()

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
tvm_quantize_zp = tvm_quantize_zp.squeeze()
tvm_quantize_axis = op.attr("quant_axis")

if tvm_quantize_axis == -1:
tvm_quantize_axis = 0

out = _qnn.op.quantize(
data=data_node,
output_scale=_op.const(tvm_quantize_scale, "float32"),
output_zero_point=_op.const(tvm_quantize_zp, "int32"),
output_scale=_expr.const(tvm_quantize_scale, "float32"),
output_zero_point=_expr.const(tvm_quantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)
Expand Down Expand Up @@ -2446,14 +2450,12 @@ def convert_slice(g, op, block):
def convert_softmax(g, op, block):
"""Operator converter for softmax."""

x = g.get_node(op.input("X")[0])
data = g.get_node(op.input("X")[0])
axis = op.attr("axis")
input_shape = block.var(op.input("X")[0]).shape
if axis < 0:
axis = len(input_shape) + axis
m = _op.max(x, axis, keepdims=True)
e = _op.exp(x - m)
out = e / _op.sum(e, axis, keepdims=True)
out = _op.nn.softmax(data, axis)
g.add_node(op.output("Out")[0], out)


Expand Down
10 changes: 7 additions & 3 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def check_qnn_softmax(pattern):
zero_point = pattern.args[2].data.numpy().item(0)

# check for dtypes of quantize and dequantize
# if (
# (scale == 1.0 / 256 and zero_point == -128)
# and pattern.attrs.out_dtype == "int8"
# and dequantize_call.args[0].checked_type.dtype == "int8"
# ):
# return True
if (
(scale == 1.0 / 256 and zero_point == -128)
and pattern.attrs.out_dtype == "int8"
pattern.attrs.out_dtype == "int8"
and dequantize_call.args[0].checked_type.dtype == "int8"
):
return True
Expand All @@ -99,7 +104,6 @@ def check_qnn_softmax(pattern):
and dequantize_call.args[0].checked_type.dtype == "int16"
):
return True

return False

def qnn_conv2d_pattern(with_pad):
Expand Down