Skip to content

Commit 8455eba

Browse files
committed
[QNN] Refactor fixed point multiplication in requantize
1 parent 85a1d3f commit 8455eba

File tree

4 files changed

+183
-104
lines changed

4 files changed

+183
-104
lines changed

src/relay/pass/pattern_util.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,14 @@ inline Expr ZerosLike(Expr e) {
336336
return CallNode::make(op, {e});
337337
}
338338

339+
inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
340+
auto attrs = make_node<InitOpAttrs>();
341+
attrs->shape = std::move(shape);
342+
attrs->dtype = std::move(dtype);
343+
static const Op& op = Op::Get("zeros");
344+
return CallNode::make(op, {}, Attrs(attrs), {});
345+
}
346+
339347
inline Expr OnesLike(Expr e) {
340348
static const Op& op = Op::Get("ones_like");
341349
return CallNode::make(op, {e});

src/relay/qnn/op/requantize.cc

Lines changed: 9 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
3737

3838
// Lowering of qnn.requantize op
3939

40-
/*
41-
* \brief Convert FP32 representation into fixed point representation.
42-
* \param double_multplier The input FP32 number.
43-
* \return The pair of multiplier and shift for fixed point representation.
44-
* \note Converts a floating point number so that it can be represented by
45-
* integers. The representation is
46-
* float_number = (significand) * 2^(exponent)
47-
*
48-
* The significand is a number between 0.5 and 1. This is represented by
49-
* an integer number. For example, if it is int32, then the decimal point
50-
* exists between bit 31 and 30 from LSB (or between first and second bit
51-
* from the left).
52-
*
53-
* Some examples are
54-
* 0.25 = (0.5) * 2^(-1)
55-
* 0.125 = (0.5) * 2^(-2)
56-
*
57-
* Credit to TFLite reference implementation.
58-
*/
59-
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
60-
int32_t significand, exponent;
61-
if (double_multiplier == 0.) {
62-
significand = 0;
63-
exponent = 0;
64-
return std::make_pair(significand, exponent);
65-
}
6640

67-
// Get the significand and exponent.
68-
double significand_d = std::frexp(double_multiplier, &exponent);
69-
70-
// Convert the double significand to int significand, i.e., convert into a
71-
// integer where the decimal point is between bit 31 and 30. This is done by
72-
// multiplying the double value with 2^31 and then casting to int.
73-
significand_d = std::round(significand_d * (1ll << 31));
74-
auto significand_int64 = static_cast<int64_t>(significand_d);
75-
CHECK_LE(significand_int64, (1ll << 31));
76-
if (significand_int64 == (1ll << 31)) {
77-
significand_int64 /= 2;
78-
++exponent;
79-
}
80-
CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
81-
significand = static_cast<int32_t>(significand_int64);
82-
return std::make_pair(significand, exponent);
83-
}
8441

8542
/*
8643
* \brief Lower requantize to a sequence of ops.
@@ -93,93 +50,41 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
9350
* and shift. This is useful, if the target device does not support/have
9451
* very expensive floating point computations.
9552
*
96-
* Original compuation is scale_fp32 * quantized_tensor. To convert into
97-
* integer computation, the multiplication with fp32 scalar can be
98-
* replaced by multiplication with an int value and then right shifting
99-
* the result. This approximates the floating point computation with a
100-
* fixed point computation.
101-
*
10253
* The whole computation can be broken down into the following steps:
10354
* 1) Calculate the integer multiplier and integer shift.
10455
* 2) Subtract the input integer zero point.
105-
* 3) Multiply the fixed point multiplier with quantized tensor.
106-
* 4) Round the result.
107-
* 5) Right shift the result.
108-
* 6) Add the output zero point.
109-
* 7) Cast to the out_dtype.
56+
* 3) Perform fixed point multiplication.
57+
* 4) Add the output zero point.
58+
* 5) Cast to the out_dtype.
11059
*/
11160
Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
11261
const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
11362
double double_multiplier = param->input_scale / param->output_scale;
11463

115-
// Choose high precision datatype to be int64. This is for avoiding overflow
116-
// in multiplication of two int32 values.
11764
DataType hp_dtype = Int(64);
11865

119-
// 1) Calculating the integer multiplier and integer shift
120-
int32_t fixed_point_multiplier, shift;
121-
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
122-
int left_shift = shift > 0 ? shift : 0;
123-
int right_shift = shift > 0 ? 0 : -shift;
124-
125-
// 2) Subtract the input_zero_point
12666
auto tensor = Cast(input_tensor, hp_dtype);
67+
// 1) Subtract the input_zero_point
12768
if (param->input_zero_point != 0) {
12869
auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
12970
tensor = Subtract(tensor, input_zp);
13071
}
13172

132-
// If the input and output scales are same, we can skip the fixed point multiplication.
73+
// 2) If the input and output scales are the same, we can skip the fixed point multiplication.
13374
auto scaled_int64_t = tensor;
13475
if (param->input_scale != param->output_scale) {
135-
// 3) Multiply the integer multiplier
136-
if (left_shift != 0) {
137-
tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
138-
}
139-
// Perform the multiplication in higher precision.
140-
// The scalar is a fixed point value of int32 where the decimal point is
141-
// between bits 31 and 30. After multiplying with input_tensor, the result is
142-
// in int64 where the decimal point is sitting between bits 31 and 30 (from
143-
// the right, rightmost bit is bit 0). The computation is performed in higher
144-
// precision to avoid overflow in multiplying two int32 values.
145-
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
146-
auto multiplied_t = Multiply(tensor, scalar);
147-
148-
// 4) Find the rounding scalar. This depends on where the final decimal point
149-
// sits. As we will be right shifting the multiplied_t, we need to first
150-
// calculate the total_right_shift.
151-
int total_right_shift = right_shift + 31;
152-
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
153-
154-
tensor = multiplied_t;
155-
Expr round_scalar;
156-
if (param->rounding == "UPWARD") {
157-
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
158-
} else if (param->rounding == "TONEAREST") {
159-
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
160-
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
161-
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
162-
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
163-
164-
auto zero = MakeConstantScalar(hp_dtype, 0);
165-
auto zero_t = Full(zero, input_shape, hp_dtype);
166-
round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
167-
}
168-
// Add the rounding scalar.
169-
tensor = Add(tensor, round_scalar);
170-
171-
// 5) Simply right shift the result to get the final output.
172-
scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
76+
scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
77+
param->rounding);
17378
}
17479

175-
// 6) Add the output zero point.
80+
// 3) Add the output zero point.
17681
auto shifted_int64_t = scaled_int64_t;
17782
if (param->output_zero_point != 0) {
17883
auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
17984
shifted_int64_t = Add(output_zp, scaled_int64_t);
18085
}
18186

182-
// 7) Clip to the out_dtype min/max.
87+
// 4) Clip to the out_dtype min/max.
18388
auto q_min = GetQmin(out_dtype);
18489
auto q_max = GetQmax(out_dtype);
18590
auto clipped_t = Clip(shifted_int64_t, q_min, q_max);

src/relay/qnn/util.cc

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file src/relay/qnn/util.cc
23+
* \brief Utility functions for QNN.
24+
*/
25+
26+
#include "util.h"
27+
#include "../pass/pattern_util.h"
28+
29+
namespace tvm {
30+
namespace relay {
31+
namespace qnn {
32+
33+
/*
34+
* \brief Convert FP32 representation into fixed point representation.
35+
* \param double_multiplier The input FP32 number.
36+
* \return The pair of multiplier and shift for fixed point representation.
37+
* \note Converts a floating point number so that it can be represented by
38+
* integers. The representation is
39+
* float_number = (significand) * 2^(exponent)
40+
*
41+
* The significand is a number between 0.5 and 1. This is represented by
42+
* an integer number. For example, if it is int32, then the decimal point
43+
* exists between bit 31 and 30 from LSB (or between first and second bit
44+
* from the left).
45+
*
46+
* Some examples are
47+
* 0.25 = (0.5) * 2^(-1)
48+
* 0.125 = (0.5) * 2^(-2)
49+
*
50+
* Credit to TFLite reference implementation.
51+
*/
52+
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
53+
double double_multiplier) {
54+
int32_t significand, exponent;
55+
if (double_multiplier == 0.) {
56+
significand = 0;
57+
exponent = 0;
58+
return std::make_pair(significand, exponent);
59+
}
60+
61+
// Get the significand and exponent.
62+
double significand_d = std::frexp(double_multiplier, &exponent);
63+
64+
// Convert the double significand to int significand, i.e., convert into an
65+
// integer where the decimal point is between bit 31 and 30. This is done by
66+
// multiplying the double value with 2^31 and then casting to int.
67+
significand_d = std::round(significand_d * (1ll << 31));
68+
auto significand_int64 = static_cast<int64_t>(significand_d);
69+
CHECK_LE(significand_int64, (1ll << 31));
70+
if (significand_int64 == (1ll << 31)) {
71+
significand_int64 /= 2;
72+
++exponent;
73+
}
74+
CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
75+
significand = static_cast<int32_t>(significand_int64);
76+
return std::make_pair(significand, exponent);
77+
}
78+
79+
Expr FixedPointMuliply(Expr tensor, double multiplier,
80+
const Array<IndexExpr>& input_shape, const std::string& rounding) {
81+
// Choose high precision datatype to be int64. This is for avoiding overflow
82+
// in multiplication of two int32 values.
83+
DataType hp_dtype = Int(64);
84+
85+
// 1) Calculating the integer multiplier and integer shift
86+
int32_t fixed_point_multiplier, shift;
87+
std::tie(fixed_point_multiplier, shift) =
88+
GetFixedPointMultiplierShift(multiplier);
89+
int left_shift = shift > 0 ? shift : 0;
90+
int right_shift = shift > 0 ? 0 : -shift;
91+
92+
// 2) Left shift the tensor first when the exponent (shift) is positive.
93+
if (left_shift != 0) {
94+
tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
95+
}
96+
97+
// 3) Perform the multiplication in higher precision.
98+
// The scalar is a fixed point value of int32 where the decimal point is
99+
// between bits 31 and 30. After multiplying with input_tensor, the result
100+
// is in int64 where the decimal point is sitting between bits 31 and 30
101+
// (from the right, rightmost bit is bit 0). The computation is performed in
102+
// higher precision to avoid overflow in multiplying two int32 values.
103+
Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
104+
tensor = Multiply(tensor, scalar);
105+
106+
// 4) Find the rounding scalar. This depends on where the final decimal
107+
// point sits. As we will be right shifting the multiplied tensor, we need to
108+
// first calculate the total_right_shift.
109+
int total_right_shift = right_shift + 31;
110+
int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
111+
112+
Expr round_scalar;
113+
if (rounding == "UPWARD") {
114+
round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
115+
} else if (rounding == "TONEAREST") {
116+
auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
117+
auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
118+
auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
119+
auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
120+
121+
auto zero_t = Zeros(input_shape, hp_dtype);
122+
round_scalar =
123+
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
124+
}
125+
// Add the rounding scalar.
126+
tensor = Add(tensor, round_scalar);
127+
128+
// 5) Simply right shift the result to get the final output.
129+
tensor =
130+
RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
131+
132+
return tensor;
133+
}
134+
135+
} // namespace qnn
136+
} // namespace relay
137+
} // namespace tvm
138+

src/relay/qnn/util.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <tvm/expr.h>
2929
#include <tvm/relay/expr.h>
30+
#include <tvm/relay/qnn/attrs.h>
3031
#include <limits>
3132
#include <string>
3233
#include <utility>
@@ -92,7 +93,34 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
9293
return value_ptr[0];
9394
}
9495

96+
/*
97+
* \brief Fixed point multiplication of an integer tensor by a floating point
98+
scalar.
99+
* \param tensor The quantized input tensor of dtype int64.
100+
* \param multiplier The scalar multiplier.
101+
* \param input_shape Shape of the input tensor.
102+
* \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
103+
is midway between two representable values.
104+
* \return The sequence of Relay ops for fixed point multiplication.
105+
106+
* \note Original computation is scale_fp32 * quantized_tensor. To convert into
107+
* integer computation, the multiplication with fp32 scalar can be
108+
* replaced by multiplication with an int value and then right shifting
109+
* the result. This approximates the floating point computation with a
110+
* fixed point computation.
111+
*
112+
* Fixed point multiplication consists of the following
113+
steps:
114+
* 1) Multiply the fixed point multiplier with quantized tensor.
115+
* 2) Round the result.
116+
* 3) Right shift the result
117+
*/
118+
Expr FixedPointMuliply(Expr tensor, double multiplier,
119+
const Array<IndexExpr>& input_shape,
120+
const std::string& rounding);
121+
95122
} // namespace qnn
96123
} // namespace relay
97124
} // namespace tvm
98125
#endif // TVM_RELAY_QNN_UTIL_H_
126+

0 commit comments

Comments
 (0)