namespace tvm {
namespace relay {

Expr ConvolveQuantizedTensors(const Expr& quantized_data,
                              const Expr& quantized_kernel,
                              const QuantizedConv2DAttrs* param) {
  // TODO(janimesh) - Who should decide the accumulation dtype?
  if (param->input_zero_point == 0 && param->kernel_zero_point == 0) {
    // Accumulate in int32 to avoid overflowing the int8 products.
    Expr int8_conv = Conv2D(quantized_data,
                            quantized_kernel,
                            param->strides,
                            param->padding,
                            param->dilation,
                            param->groups,
                            param->channels,
                            param->kernel_size,
                            param->data_layout,
                            param->kernel_layout,
                            param->out_layout,
                            Int(32));
    return int8_conv;
  }
  LOG(FATAL) << "Only symmetric quantization is supported for now.";
  return Expr();  // to hide the warning.
}

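// For reference (an editorial note, not part of the original change): the
// asymmetric case rejected above unrolls into four terms. With data QA
// (zero point zp_a) and kernel QW (zero point zp_w), and the scales factored
// out,
//
//     conv(QA - zp_a, QW - zp_w) = conv(QA, QW)
//                                  - zp_w * sum_{ic, r, s} QA
//                                  - zp_a * sum_{ic, r, s} QW
//                                  + zp_a * zp_w * (ic * r * s)
//
// This is the unrolling described in the discuss.tvm.ai thread linked in
// QuantizedConv2DForwardRewrite below.
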
Expr ScaleHandling(const Expr& convolved_tensor,
                   const QuantizedConv2DAttrs* param) {
  // The scale handling can be done in many ways.
  // 1) Floating point handling
  //    Here we can multiply the scale into the convolved_tensor, round to the
  //    nearest integer and then cast back to int32.
  // 2) Integer only scale handling
  //    Here, the computation is converted to a fixed point computation by
  //    computing an output multiplier and shift. This is useful if the target
  //    device does not support floating point computations, or supports them
  //    only at a high cost.

  if (!param->use_integer_computation_for_scale_handling) {
    double multiplier = (param->input_scale * param->kernel_scale) /
                        param->output_scale;
    auto scalar_multiplier = MakeConstantScalar(Float(32), multiplier);
    auto casted_convolved_tensor = Cast(convolved_tensor, Float(32));
    auto scaled_fp32_tensor = Multiply(casted_convolved_tensor, scalar_multiplier);
    auto scaled_rounded_fp32_tensor = Round(scaled_fp32_tensor);
    auto scaled_tensor = Cast(scaled_rounded_fp32_tensor, Int(32));
    return scaled_tensor;
  }
  LOG(FATAL) << "Only floating point scale handling is supported for now.";
  return Expr();  // to hide the warning.
}

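// A minimal sketch (an editorial addition, not part of the original change) of
// how the integer-only branch of ScaleHandling could derive a fixed point
// multiplier and shift from the real multiplier, following the usual
// gemmlowp/TF-Lite decomposition. The helper name is hypothetical; it needs
// <cmath> and <cstdint>.
static void QuantizeMultiplier(double real_multiplier,
                               int32_t* fixed_point_multiplier, int* shift) {
  // Decompose real_multiplier as significand * 2^shift with the significand
  // in [0.5, 1).
  double significand = std::frexp(real_multiplier, shift);
  // Round the significand into a Q31 fixed point integer.
  int64_t q = static_cast<int64_t>(std::round(significand * (1LL << 31)));
  if (q == (1LL << 31)) {
    // Rounding pushed the significand up to 1.0; renormalize.
    q /= 2;
    ++(*shift);
  }
  *fixed_point_multiplier = static_cast<int32_t>(q);
}
// The scaled value is then approximately
// (x * fixed_point_multiplier) >> (31 - shift), with round-to-nearest
// handling of the discarded bits.
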
Expr ReQuantize(const Expr& scaled_output,
                const QuantizedConv2DAttrs* param) {
  Expr requantized_output = Cast(scaled_output, param->out_dtype);
  return requantized_output;
}

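// Note (editorial): a fuller requantize step would normally saturate to the
// output dtype's representable range before the cast, e.g. [-128, 127] for
// int8. A hedged sketch, assuming a Clip helper in the style of the Cast and
// Multiply helpers used above:
//
//   Expr clipped_output = Clip(scaled_output, -128, 127);
//   return Cast(clipped_output, param->out_dtype);
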
Expr QuantizedConv2DForwardRewrite(const Call& ref_call,
                                   const Array<Expr>& new_args,
                                   const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  Expr quantized_data = new_args[0];
  Expr quantized_kernel = new_args[1];
  // The attrs lookup is elided in the original excerpt; the standard pattern
  // is assumed here.
  const auto* param = ref_call->attrs.as<QuantizedConv2DAttrs>();

  // Check for current quantization support.
  CHECK_EQ(param->input_zero_point, 0)
      << "Encountered a non-zero zero point."
      << " Only symmetric quantization is supported for now.";
  CHECK_EQ(param->kernel_zero_point, 0)
      << "Encountered a non-zero zero point."
      << " Only symmetric quantization is supported for now.";
  CHECK_EQ(param->output_zero_point, 0)
      << "Encountered a non-zero zero point."
      << " Only symmetric quantization is supported for now.";
  CHECK_EQ(param->use_integer_computation_for_scale_handling, false)
      << "Only floating point scale handling is supported for now."
      << " Integer-only scale handling is not yet implemented.";

  // Lowering of the quantized_convolution.
  //
  // For FP32, the conv output is
  //     C = conv(A, W)
  // or, C(n, oc, oh, ow) = A(n, ic, oh + r, ow + s) * W(oc, ic, r, s)
  // where ic, r and s are the reduction axes.
  //
  // For quantized convolution, each tensor is represented in a quantized
  // format
  //     A = scale_a * (QA - zp_a)
  // where QA is the quantized tensor, and scale_a and zp_a are its
  // quantization params.
  //
  // For symmetric quantization, zp_* is 0 for all tensors, so the quantized
  // convolution becomes
  //
  //     scale_c * QC(n, oc, oh, ow) =
  //         scale_a * QA(n, ic, oh + r, ow + s) *
  //         scale_w * QW(oc, ic, r, s)
  //
  // So, to get the quantized tensor QC, the computation is
  //
  //     QC(n, oc, oh, ow) = (scale_a * scale_w) / scale_c *
  //         QA(n, ic, oh + r, ow + s) * QW(oc, ic, r, s)
  //
  // or,
  //     QC = K * conv(QA, QW)
  //
  // For asymmetric computation, we can perform a similar unrolling. More
  // details are at
  // https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh

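  // As a concrete check of the formula: with scale_a = 0.5, scale_w = 0.02
  // and scale_c = 0.1, K = (0.5 * 0.02) / 0.1 = 0.1, so each int32 element
  // of conv(QA, QW) is scaled by 0.1 and rounded before the final cast.
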
  // The above computation is arranged in the following functions
  //   1) ConvolveQuantizedTensors
  //      a) For symmetric, conv(QA, QW).
  //      b) For asymmetric, it involves 4 terms.
  //   2) ScaleHandling
  //      a) Takes the convolved output and scales it.
  //      b) Can support both float and integer computation.
  //   3) ReQuantize
  //      a) Converts the intermediate dtype back to the output dtype,
  //         e.g. int8.
  Expr convolved_tensor = ConvolveQuantizedTensors(quantized_data,
                                                   quantized_kernel,
                                                   param);
  Expr scaled_output = ScaleHandling(convolved_tensor, param);
  Expr requantized_output = ReQuantize(scaled_output, param);
  // TODO(janimesh) - Look at the literature and use the right scale
  // calculations.
  return requantized_output;
}

RELAY_REGISTER_OP("nn_quantized.quantized_conv2d")