@@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
3737
3838// Lowering of qnn.requantize op
3939
40- /*
41- * \brief Convert FP32 representation into fixed point representation.
42- * \param double_multplier The input FP32 number.
43- * \return The pair of multiplier and shift for fixed point representation.
44- * \note Converts a floating point number so that it can be represented by
45- * integers. The representation is
46- * float_number = (significand) * 2^(exponent)
47- *
48- * The significand is a number between 0.5 and 1. This is represented by
49- * an integer number. For example, if it is int32, then the decimal point
50- * exists between bit 31 and 30 from LSB (or between first and second bit
51- * from the left).
52- *
53- * Some examples are
54- * 0.25 = (0.5) * 2^(-1)
55- * 0.125 = (0.5) * 2^(-2)
56- *
57- * Credit to TFLite reference implementation.
58- */
59- std::pair<int32_t , int32_t > GetFixedPointMultiplierShift (double double_multiplier) {
60- int32_t significand, exponent;
61- if (double_multiplier == 0 .) {
62- significand = 0 ;
63- exponent = 0 ;
64- return std::make_pair (significand, exponent);
65- }
6640
67- // Get the significand and exponent.
68- double significand_d = std::frexp (double_multiplier, &exponent);
69-
70- // Convert the double significand to int significand, i.e., convert into a
71- // integer where the decimal point is between bit 31 and 30. This is done by
72- // multiplying the double value with 2^31 and then casting to int.
73- significand_d = std::round (significand_d * (1ll << 31 ));
74- auto significand_int64 = static_cast <int64_t >(significand_d);
75- CHECK_LE (significand_int64, (1ll << 31 ));
76- if (significand_int64 == (1ll << 31 )) {
77- significand_int64 /= 2 ;
78- ++exponent;
79- }
80- CHECK_LE (significand_int64, std::numeric_limits<int32_t >::max ());
81- significand = static_cast <int32_t >(significand_int64);
82- return std::make_pair (significand, exponent);
83- }
8441
8542/*
8643 * \brief Lower requantize to a sequence of ops.
@@ -93,93 +50,41 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
9350 * and shift. This is useful, if the target device does not support/have
9451 * very expensive floating point computations.
9552 *
96- * Original compuation is scale_fp32 * quantized_tensor. To convert into
97- * integer computation, the multiplication with fp32 scalar can be
98- * replaced by multiplication with an int value and then right shifting
99- * the result. This approximates the floating point computation with a
100- * fixed point computation.
101- *
10253 * The whole computation this can be broken down into following steps
10354 * 1) Calculate the integer multiplier and integer shift.
10455 * 2) Subtract the input integer zero point.
105- * 3) Multiply the fixed point multiplier with quantized tensor.
106- * 4) Round the result.
107- * 5) Right shift the result.
108- * 6) Add the output zero point.
109- * 7) Cast to the out_dtype.
56+ * 3) Perform fixed point multiplication.
57+ * 4) Add the output zero point.
58+ * 5) Cast to the out_dtype.
11059 */
11160Expr RequantizeLower (const Expr& input_tensor, const RequantizeAttrs* param,
11261 const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
11362 double double_multiplier = param->input_scale / param->output_scale ;
11463
115- // Choose high precision datatype to be int64. This is for avoiding overflow
116- // in multiplication of two int32 values.
11764 DataType hp_dtype = Int (64 );
11865
119- // 1) Calculating the integer multiplier and integer shift
120- int32_t fixed_point_multiplier, shift;
121- std::tie (fixed_point_multiplier, shift) = GetFixedPointMultiplierShift (double_multiplier);
122- int left_shift = shift > 0 ? shift : 0 ;
123- int right_shift = shift > 0 ? 0 : -shift;
124-
125- // 2) Subtract the input_zero_point
12666 auto tensor = Cast (input_tensor, hp_dtype);
67+ // 1) Subtract the input_zero_point
12768 if (param->input_zero_point != 0 ) {
12869 auto input_zp = MakeConstantScalar (hp_dtype, param->input_zero_point );
12970 tensor = Subtract (tensor, input_zp);
13071 }
13172
132- // If the input and output scales are same, we can skip the fixed point multiplication.
73+ // 2) If the input and output scales are same, we can skip the fixed point multiplication.
13374 auto scaled_int64_t = tensor;
13475 if (param->input_scale != param->output_scale ) {
135- // 3) Multiply the integer multiplier
136- if (left_shift != 0 ) {
137- tensor = Multiply (tensor, MakeConstantScalar (hp_dtype, 1 << left_shift));
138- }
139- // Perform the multiplication in higher precision.
140- // The scalar is a fixed point value of int32 where the decimal point is
141- // between bits 31 and 30. After multiplying with input_tensor, the result is
142- // in int64 where the decimal point is sitting between bits 31 and 30 (from
143- // the right, rightmost bit is bit 0). The computation is performed in higher
144- // precision to avoid overflow in multiplying two int32 values.
145- Expr scalar = MakeConstantScalar (hp_dtype, fixed_point_multiplier);
146- auto multiplied_t = Multiply (tensor, scalar);
147-
148- // 4) Find the rounding scalar. This depends on where the final decimal point
149- // sits. As we will be right shifting the multiplied_t, we need to first
150- // calculate the total_right_shift.
151- int total_right_shift = right_shift + 31 ;
152- int64_t pos_rounding_value = (1ll << (total_right_shift - 1 ));
153-
154- tensor = multiplied_t ;
155- Expr round_scalar;
156- if (param->rounding == " UPWARD" ) {
157- round_scalar = MakeConstantScalar (hp_dtype, pos_rounding_value);
158- } else if (param->rounding == " TONEAREST" ) {
159- auto pos_rounder = MakeConstantScalar (hp_dtype, pos_rounding_value);
160- auto neg_rounder = MakeConstantScalar (hp_dtype, pos_rounding_value - 1 );
161- auto pos_rounder_t = Full (pos_rounder, input_shape, hp_dtype);
162- auto neg_rounder_t = Full (neg_rounder, input_shape, hp_dtype);
163-
164- auto zero = MakeConstantScalar (hp_dtype, 0 );
165- auto zero_t = Full (zero, input_shape, hp_dtype);
166- round_scalar = Where (GreaterEqual (tensor, zero_t ), pos_rounder_t , neg_rounder_t );
167- }
168- // Add the rounding scalar.
169- tensor = Add (tensor, round_scalar);
170-
171- // 5) Simply right shift the result to get the final output.
172- scaled_int64_t = RightShift (tensor, MakeConstantScalar (hp_dtype, total_right_shift));
76+ scaled_int64_t = FixedPointMuliply (scaled_int64_t , double_multiplier, input_shape,
77+ param->rounding );
17378 }
17479
175- // 6 ) Add the output zero point.
80+ // 3 ) Add the output zero point.
17681 auto shifted_int64_t = scaled_int64_t ;
17782 if (param->output_zero_point != 0 ) {
17883 auto output_zp = MakeConstantScalar (hp_dtype, param->output_zero_point );
17984 shifted_int64_t = Add (output_zp, scaled_int64_t );
18085 }
18186
182- // 7 ) Clip to the out_dtype min/max.
87+ // 4 ) Clip to the out_dtype min/max.
18388 auto q_min = GetQmin (out_dtype);
18489 auto q_max = GetQmax (out_dtype);
18590 auto clipped_t = Clip (shifted_int64_t , q_min, q_max);
0 commit comments