3131#include < tvm/relay/attrs/annotation.h>
3232#include " ./quantize.h"
3333#include " ../pattern_util.h"
34+ #include " ../../qnn/util.h"
3435
3536namespace tvm {
3637namespace relay {
@@ -97,7 +98,8 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
9798
9899
99100/* calculate `data * s1 / s2`, use shift if possible */
100- inline Expr MulAndDiv (Expr data, float s1, float s2, DataType dtype) {
101+ inline Expr MulAndDiv (Expr data, float s1, float s2, DataType dtype, const Array<IndexExpr> &data_shape) {
102+ const QConfig& cfg = QConfig::Current ();
101103 // here we assume the dtype of data is dtype activation
102104 if (s1 == s2) return data;
103105
@@ -110,9 +112,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
110112 } else if (static_cast <int >(factor) == factor) {
111113 return Multiply (data, MakeConstantScalar (dtype, factor));
112114 } else {
113- data = Cast (data, Float (32 ));
114- data = Multiply (data, MakeConstantScalar (Float (32 ), factor));
115- return Cast (Round (data), dtype);
115+ data = qnn::FixedPointMultiply (Cast (data, Int (64 )), factor, data_shape, cfg->rounding );
116+ return Cast (data, dtype);
116117 }
117118}
118119
@@ -164,11 +165,12 @@ Expr QuantizeRealize(const Call& ref_call,
164165 data = Clip (data, clip_min_imm, clip_max_imm);
165166 return QRealizeIntExprNode::make (data, dom_scale, n->dtype );
166167 } else {
167- // float computation
168- data = Cast (data, Float (32 ));
169- Expr scaled_data = Multiply (data, Divide (n->dom_scale , dom_scale));
170- Expr round_data = Clip (Round (scaled_data), clip_min_imm, clip_max_imm);
171- return QRealizeIntExprNode::make (round_data, dom_scale, Float (32 ));
168+ data = Cast (data, Int (64 ));
169+ data = qnn::FixedPointMultiply (data, idom_scale_imm / odom_scale_imm,
170+ ref_call->type_as <TensorTypeNode>()->shape ,
171+ cfg->rounding );
172+ data = Cast (Clip (data, clip_min_imm, clip_max_imm), n->dtype );
173+ return QRealizeIntExprNode::make (data, dom_scale, n->dtype );
172174 }
173175 }
174176
@@ -355,7 +357,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
355357 Expr dom_scale = MakeConstantScalar (Float (32 ), s);
356358 for (size_t i = 0 ; i < ret.size (); ++i) {
357359 float cur_s = GetScalarFromConstant<float >(nptrs[i]->dom_scale );
358- ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype));
360+ ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype, ref_args[i]-> type_as <TensorTypeNode>()-> shape ));
359361 }
360362
361363 *dtype_ptr = dtype;
0 commit comments