3131#include  < tvm/relay/attrs/annotation.h> 
3232#include  " ./quantize.h" 
3333#include  " ../pattern_util.h" 
34+ #include  " ../../qnn/util.h" 
3435
3536namespace  tvm  {
3637namespace  relay  {
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
9798
9899
99100/*  calculate `data * s1 / s2`, use shift if possible */ 
100- inline  Expr MulAndDiv (Expr data, float  s1, float  s2, DataType dtype) {
101+ inline  Expr MulAndDiv (Expr data, float  s1, float  s2, DataType dtype,
102+                       const  Array<IndexExpr> &data_shape) {
103+   const  QConfig& cfg = QConfig::Current ();
101104  //  here we assume the dtype of data is dtype activation
102105  if  (s1 == s2) return  data;
103106
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
110113  } else  if  (static_cast <int >(factor) == factor) {
111114    return  Multiply (data, MakeConstantScalar (dtype, factor));
112115  } else  {
113-     data = Cast (data, Float (32 ));
114-     data = Multiply (data, MakeConstantScalar (Float (32 ), factor));
115-     return  Cast (Round (data), dtype);
116+     data = qnn::FixedPointMultiply (Cast (data, Int (64 )), factor, data_shape, cfg->rounding );
117+     return  Cast (data, dtype);
116118  }
117119}
118120
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
164166      data = Clip (data, clip_min_imm, clip_max_imm);
165167      return  QRealizeIntExprNode::make (data, dom_scale, n->dtype );
166168    } else  {
167-       //  float computation
168-       data = Cast (data, Float (32 ));
169-       Expr scaled_data = Multiply (data, Divide (n->dom_scale , dom_scale));
170-       Expr round_data = Clip (Round (scaled_data), clip_min_imm, clip_max_imm);
171-       return  QRealizeIntExprNode::make (round_data, dom_scale, Float (32 ));
169+       data = Cast (data, Int (64 ));
170+       data = qnn::FixedPointMultiply (data, idom_scale_imm / odom_scale_imm,
171+                                      ref_call->type_as <TensorTypeNode>()->shape ,
172+                                      cfg->rounding );
173+       data = Cast (Clip (data, clip_min_imm, clip_max_imm), n->dtype );
174+       return  QRealizeIntExprNode::make (data, dom_scale, n->dtype );
172175    }
173176  }
174177
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
355358  Expr dom_scale = MakeConstantScalar (Float (32 ), s);
356359  for  (size_t  i = 0 ; i < ret.size (); ++i) {
357360    float  cur_s = GetScalarFromConstant<float >(nptrs[i]->dom_scale );
358-     ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype));
361+     ret.Set (i, MulAndDiv (ret[i], cur_s, s, dtype, ref_args[i]-> type_as <TensorTypeNode>()-> shape ));
359362  }
360363
361364  *dtype_ptr = dtype;
0 commit comments