Commit a91949d

zhaoyang-star authored and Mikael Sevenier committed
[QNN] Support different qnn params between in/out tensor in leaky_relu (apache#12116)
* [QNN] Support different qnn params between in/out tensor in leaky_relu
* format code
* format code
* fix bug
* fix format
* fix format
* fix
1 parent fe043bf commit a91949d


5 files changed: +104 -47 lines changed


python/tvm/relay/frontend/qnn_torch.py

Lines changed: 5 additions & 1 deletion
@@ -963,7 +963,11 @@ def _impl_int8(inputs, _):
         alpha = inputs[1]
         output_scale = _expr.const(inputs[3])
         output_zero_point = _expr.const(inputs[4])
-        return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)
+        input_scale = _expr.const(inputs[5])
+        input_zero_point = _expr.const(inputs[6])
+        return relay.qnn.op.leaky_relu(
+            inputs[0], alpha, input_scale, input_zero_point, output_scale, output_zero_point
+        )

     def _impl(inputs, _):
         assert len(inputs) == 7, "Input quant params not found in op inputs"
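
For reference, the converter above reads the packed PyTorch op inputs at fixed
positions; a minimal sketch of that layout, covering only the indices this diff
touches (consistent with the `assert len(inputs) == 7` in `_impl`):

# inputs[0] -> quantized input tensor
# inputs[1] -> alpha (the negative slope)
# inputs[3] -> output_scale,  inputs[4] -> output_zero_point
# inputs[5] -> input_scale,   inputs[6] -> input_zero_point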

python/tvm/relay/qnn/op/qnn.py

Lines changed: 13 additions & 8 deletions
@@ -1179,7 +1179,7 @@ def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="
 reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE)


-def leaky_relu(x, alpha, scale, zero_point):
+def leaky_relu(x, alpha, input_scale, input_zero_point, output_scale, output_zero_point):
     """Quantized leaky relu.

     Parameters
@@ -1188,11 +1188,14 @@ def leaky_relu(x, alpha, scale, zero_point):
         The quantized input tensor.
     alpha: double
         The alpha value.
-    scale: relay.Expr
-        The scale of the quantized expr.
-    zero_point: relay.Expr
-        The zero point of quantized expr.
-
+    input_scale: relay.Expr
+        The scale of the input quantized expr.
+    input_zero_point: relay.Expr
+        The zero point of input quantized expr.
+    output_scale: relay.Expr
+        The scale of the output quantized expr.
+    output_zero_point: relay.Expr
+        The zero point of output quantized expr.
     Returns
     -------
     result : relay.Expr
@@ -1201,6 +1204,8 @@ def leaky_relu(x, alpha, scale, zero_point):
     return _make.leaky_relu(
         x,
         alpha,
-        scale,
-        zero_point,
+        input_scale,
+        input_zero_point,
+        output_scale,
+        output_zero_point,
     )
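
A quick usage sketch of the updated Python API, mirroring the unit test at the
bottom of this commit (the constants are illustrative):

from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="uint8")
y = relay.qnn.op.leaky_relu(
    x,
    alpha=0.9,
    input_scale=relay.const(0.125, "float32"),
    input_zero_point=relay.const(60, "int32"),
    output_scale=relay.const(0.6, "float32"),
    output_zero_point=relay.const(17, "int32"),
)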

python/tvm/relay/transform/fake_quantization_to_integer.py

Lines changed: 6 additions & 3 deletions
@@ -364,10 +364,13 @@ def relu(expr, type_map):
 def leaky_relu(expr, type_map):
     """Rewrite a leaky relu op"""
     arg = expr.args[0]
-    t = type_map[arg]
+    x_t = type_map[arg]
+    out_t = type_map[expr]
     alpha = expr.attrs.alpha
-    output = relay.qnn.op.leaky_relu(expr, alpha, t.scale, t.zero_point)
-    return [output, t]
+    output = relay.qnn.op.leaky_relu(
+        expr, alpha, x_t.scale, x_t.zero_point, out_t.scale, out_t.zero_point
+    )
+    return [output, x_t]


 @register_fake_quantization_to_integer("nn.pad")
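
To see the rewrite in action, a minimal sketch (assuming a TVM build that
includes this commit): a fake-quantized dequantize -> nn.leaky_relu -> quantize
chain should fold into a single qnn.leaky_relu carrying both sets of qnn params.

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4), dtype="int8")
deq = relay.qnn.op.dequantize(x, relay.const(0.125), relay.const(10))
act = relay.nn.leaky_relu(deq, alpha=0.1)
quant = relay.qnn.op.quantize(act, relay.const(0.25), relay.const(5), out_dtype="int8")

mod = tvm.IRModule.from_expr(relay.Function([x], quant))
# The pass invokes the leaky_relu rewrite above, folding the dequantize/quantize
# pair into qnn.leaky_relu with input (0.125, 10) and output (0.25, 5) params.
mod = relay.transform.FakeQuantizationToInteger()(mod)
print(mod["main"])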

src/relay/qnn/op/leaky_relu.cc

Lines changed: 59 additions & 26 deletions
@@ -32,8 +32,8 @@ namespace qnn {

 bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
-  // Expected Types: data, scale, zero_point
-  ICHECK_EQ(types.size(), 4);
+  // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, out_type
+  ICHECK_EQ(types.size(), 6);
   const auto* x = types[0].as<TensorTypeNode>();
   if (x == nullptr) return false;
   ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8))
@@ -42,31 +42,37 @@ bool QnnLeakyReluRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   ICHECK(param != nullptr) << "LeakyReluAttrs cannot be nullptr.";

   // Check the types of scale and zero points.
-  for (size_t i = 1; i < 3; ++i) {
+  for (size_t i = 1; i < 5; ++i) {
     if (types[i].as<IncompleteTypeNode>()) {
       return false;
     }
   }

-  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // scale
-  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // zero_point
+  ICHECK(IsScalarType(types[1], DataType::Float(32)));  // input_scale
+  ICHECK(IsScalarType(types[2], DataType::Int(32)));    // input_zero_point
+  ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
+  ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

   // Assign types for scale and zero points.
-  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // scale
-  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // zero_point
+  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // input_scale
+  reporter->Assign(types[2], TensorType({}, DataType::Int(32)));    // input_zero_point
+  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // output_scale
+  reporter->Assign(types[4], TensorType({}, DataType::Int(32)));    // output_zero_point

   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // IdentityRel infer type function.
-  Array<Type> tensor_types = {types[0], types[3]};
+  Array<Type> tensor_types = {types[0], types[5]};
   return IdentityRel(tensor_types, 2, attrs, reporter);
 }

 // Positional relay function to create quantized leaky relu operator used by frontend FFI.
-Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr scale, Expr zero_point) {
+Expr MakeQuantizedLeakyRelu(Expr x, double alpha, Expr input_scale, Expr input_zero_point,
+                            Expr output_scale, Expr output_zero_point) {
   auto attrs = make_object<LeakyReluAttrs>();
   attrs->alpha = alpha;
   static const Op& op = Op::Get("qnn.leaky_relu");
-  return Call(op, {x, scale, zero_point}, Attrs(attrs), {});
+  return Call(op, {x, input_scale, input_zero_point, output_scale, output_zero_point}, Attrs(attrs),
+              {});
 }

 /*
@@ -82,42 +88,69 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // by a small alpha value < 1.
   //
   // We assume the same scale and zero point for alpha and the input tensor.
-  // Let T = s(q_t - z) where q_t is the input arg[0]
-  // Then, the quantized value of alpha * T is:
-  // q(a * T, s, z) = [(a * T) / s] + z = a * s(q_t - z) / s + z = a * (q_t - z) + z
-  //               = a * q_t + (1 - a) * z
+  // LeakyReLU can be written in terms of the respective quantized tensors, scales and
+  // zero points as
   //
-  // We return the quantized value of alpha * T for all values q_t < input_zero_point.
-
-  ICHECK_EQ(new_args.size(), 3);
-  Expr quantized_data = Cast(new_args[0], DataType::Int(32));
+  //   scale_o * (Q_o - zp_o) = alpha * scale_i * (Q_i - zp_i)   when Q_i < zp_i   (1)
+  //   scale_o * (Q_o - zp_o) = scale_i * (Q_i - zp_i)           when Q_i >= zp_i  (2)
+  //
+  // Since the input qnn params can differ from the output qnn params, we first requantize the
+  // input tensor to the output qnn params. After requantizing Q_i, equation (1) becomes equation
+  // (3), where Q_i' is the requantized data from Q_i.
+  //
+  //   scale_o * (Q_o - zp_o) = alpha * scale_o * (Q_i' - zp_o)  when Q_i < zp_i   (3)
+  //   Q_o = alpha * Q_i' + (1 - alpha) * zp_o                   when Q_i < zp_i   (4)
+  //
+  // Equation (2), in turn, amounts to requantizing Q_i with scale_o and zp_o,
+  // so it becomes
+  //
+  //   Q_o = requantize(Q_i)                                     when Q_i >= zp_i  (5)
+  //
+  // Finally, Q_o is computed from equations (4) and (5).
+  ICHECK_EQ(new_args.size(), 5);
+  Expr data = Cast(new_args[0], DataType::Int(32));
+  Expr input_scale = new_args[1];
   Expr input_zero_point = Cast(new_args[2], DataType::Int(32));
+  Expr output_scale = new_args[3];
+  Expr output_zero_point = Cast(new_args[4], DataType::Int(32));

   const auto* q_attrs = attrs.as<LeakyReluAttrs>();
   auto alpha = q_attrs->alpha;

+  const auto input_shape = get_shape(arg_types[0]);
+  const auto input_dtype = arg_types[0].as<TensorTypeNode>()->dtype;
+
+  // requantize the input to Q_i'
+  auto requantized_expr = RequantizeOrUpcast(data, input_scale, input_zero_point, output_scale,
+                                             output_zero_point, input_shape);
+
+  // alpha * Q_i'
   int32_t fixed_point_multiplier, shift;
   std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
-  auto prod = FixedPointMultiply(quantized_data, fixed_point_multiplier, shift);
+  auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift);

+  // (1 - alpha) * zp_o
   int32_t fixed_point_multiplier_z, shift_z;
   std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
-  auto scaled_z = FixedPointMultiply(input_zero_point, fixed_point_multiplier_z, shift_z);
+  auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z);

+  // alpha * Q_i' + (1 - alpha) * zp_o
   auto add = Add(prod, scaled_z);
-  auto output = Where(Less(quantized_data, input_zero_point), add, quantized_data);
+  auto output = Where(Less(data, input_zero_point), add, requantized_expr);

-  const auto* input_type = arg_types[0].as<TensorTypeNode>();
-  return ConvertDtype(output, input_type->dtype);
+  return ConvertDtype(output, input_dtype);
 }

 RELAY_REGISTER_OP("qnn.leaky_relu")
     .describe("Leaky relu for quantized tensors.")
     .set_attrs_type<LeakyReluAttrs>()
-    .set_num_inputs(3)
+    .set_num_inputs(5)
     .add_argument("data", "Quantized Tensor", "The input data.")
-    .add_argument("scale", "Tensor", "The quantization scale of the input tensor.")
-    .add_argument("zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+    .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
+    .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
+    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
+    .add_argument("output_zero_point", "Tensor",
+                  "The quantization zero_point of the output tensor.")
     .set_support_level(11)
     .add_type_rel("QLeakyRelu", QnnLeakyReluRel)
     .set_attr<TNonComputational>("TNonComputational", true)
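
The comment block above can be sanity-checked in plain NumPy; a sketch (not TVM
code) comparing the requantize-first integer path of equations (4)-(5) against
the float reference, using the same constants as the unit test below:

import numpy as np

alpha = 0.9
scale_i, zp_i = 0.125, 60
scale_o, zp_o = 0.6, 17
q_i = np.array([255, 133, 0, 9])

# Float reference: dequantize, apply leaky relu, quantize with the output params.
real = scale_i * (q_i - zp_i)
ref = np.round(np.where(real < 0, alpha * real, real) / scale_o + zp_o)

# Integer path: requantize first (Q_i'), then equations (4) and (5).
q_i_prime = np.round(scale_i / scale_o * (q_i - zp_i) + zp_o)   # requantize(Q_i)
neg = np.round(alpha * q_i_prime + (1 - alpha) * zp_o)          # eq. (4)
q_o = np.where(q_i < zp_i, neg, q_i_prime)                      # eq. (5)

# The two paths agree up to one unit of rounding error.
assert np.all(np.abs(ref - q_o) <= 1)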

tests/python/relay/test_op_qnn_leaky_relu.py

Lines changed: 21 additions & 9 deletions
@@ -24,26 +24,36 @@ def dequantize(data, scale, zp):
     return scale * (np.asarray(data) - zp)


-def generate_golden_output(x_data, dequantized_x, alpha, scale, zero_point):
+def generate_golden_output(x_data, dequantized_x, alpha, o_scale, o_zero_point, i_zero_point):
     prod = np.multiply(dequantized_x, alpha)
-    prod = np.around(prod / scale + zero_point)
+    prod = np.around(prod / o_scale + o_zero_point)

-    output = np.where(x_data < zero_point, prod, x_data)
+    q_min = np.iinfo(np.uint8).min
+    q_max = np.iinfo(np.uint8).max
+    prod = np.clip(prod, q_min, q_max)
+
+    requantized = np.clip(np.round(dequantized_x / o_scale + o_zero_point), q_min, q_max)
+
+    output = np.where(x_data < i_zero_point, prod, requantized)
     return output


 def test_qnn_leaky_relu():
     data_dtype = "uint8"
-    scale = 0.125
-    zero_point = 60
+    input_scale = 0.125
+    input_zero_point = 60
+    output_scale = 0.6
+    output_zero_point = 17
     alpha = 0.9

     x = relay.var("x", shape=(1, 4), dtype=data_dtype)
     y = relay.qnn.op.leaky_relu(
         x=x,
         alpha=alpha,
-        scale=relay.const(scale, "float32"),
-        zero_point=relay.const(zero_point, "int32"),
+        input_scale=relay.const(input_scale, "float32"),
+        input_zero_point=relay.const(input_zero_point, "int32"),
+        output_scale=relay.const(output_scale, "float32"),
+        output_zero_point=relay.const(output_zero_point, "int32"),
     )

     func = relay.Function([x], y)
@@ -53,8 +63,10 @@ def test_qnn_leaky_relu():
     func = mod["main"]

     x_data = np.array((255, 133, 0, 9)).reshape((1, 4))
-    x_dequantized = dequantize(x_data, scale, zero_point)
-    golden_output = generate_golden_output(x_data, x_dequantized, alpha, scale, zero_point)
+    x_dequantized = dequantize(x_data, input_scale, input_zero_point)
+    golden_output = generate_golden_output(
+        x_data, x_dequantized, alpha, output_scale, output_zero_point, input_zero_point
+    )

     op_res = relay.create_executor("graph", device=tvm.cpu(0), target="llvm").evaluate(func)(x_data)
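
For the concrete inputs above, the golden reference works out as follows (a
hand-check sketch; np.round and np.around both round half to even):

import numpy as np

x = np.array([255, 133, 0, 9])
dq = 0.125 * (x - 60)                            # [24.375, 9.125, -7.5, -6.375]
neg = np.around(0.9 * dq / 0.6 + 17)             # leaky branch, in output params
pos = np.clip(np.round(dq / 0.6 + 17), 0, 255)   # requantize branch
print(np.where(x < 60, neg, pos))                # -> [58. 32. 6. 7.]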
