@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
124124 }
125125 return e;
126126 };
127- return ForwardRewrite (expr, " FQAnnotateRewrite" , nullptr , nullptr );
127+ return ForwardRewrite (expr, " FQAnnotateRewrite" , nullptr , fmulti_ref );
128128});
129129
130130
@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
329329
330330
331331/* \brief Unify the dom scale of arguments */
332- Array<Expr> UnifyDTypeScale (const Array<Expr>& args,
332+ Array<Expr> UnifyDTypeScale (const Array<Expr>& ref_args,
333+ const Array<Expr>& args,
333334 DataType* dtype_ptr,
334335 Expr* scale_ptr) {
336+ static const Op& simulated_quantize = Op::Get (" relay.op.annotation.simulated_quantize" );
335337 const QConfig& cfg = QConfig::Current ();
336338
337339 std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,17 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
344346 }
345347
346348 // unify the data type
349+ CHECK_EQ (ref_args.size (), args.size ());
347350 DataType dtype = cfg->dtype_activation ;
348351 for (size_t i = 0 ; i < ret.size (); ++i) {
352+ auto ref_arg = ref_args[i].as <CallNode>();
349353 if (nptrs[i]->dtype != dtype) {
350354 ret.Set (i, Cast (ret[i], dtype));
355+ } else if (ref_arg && ref_arg->op .same_as (simulated_quantize) &&
356+ ref_arg->attrs .as <SimulatedQuantizeAttrs>()->kind == kQInput ) {
357+ auto new_arg = Cast (ret[i], cfg->dtype_input );
358+ new_arg = StopFusion (new_arg);
359+ ret.Set (i, Cast (new_arg, dtype));
351360 }
352361 }
353362
@@ -371,7 +380,7 @@ Expr AddRealize(const Call& ref_call,
371380 if (new_args[0 ].as <QRealizeIntExprNode>() && new_args[1 ].as <QRealizeIntExprNode>()) {
372381 DataType dtype;
373382 Expr dom_scale;
374- Array<Expr> ret_args = UnifyDTypeScale (new_args, &dtype, &dom_scale);
383+ Array<Expr> ret_args = UnifyDTypeScale (ref_call-> args , new_args, &dtype, &dom_scale);
375384 Expr ret = ForwardOp (ref_call, ret_args);
376385 return QRealizeIntExprNode::make (ret, dom_scale, dtype);
377386 }
@@ -387,15 +396,19 @@ Expr ConcatenateRealize(const Call& ref_call,
387396 const Array<Expr>& new_args,
388397 const NodeRef& ctx) {
389398 CHECK_EQ (new_args.size (), 1 );
399+ CHECK_EQ (ref_call->args .size (), 1 );
390400
391401 const auto * tuple = new_args[0 ].as <TupleNode>();
402+ const auto * ref_tuple = ref_call->args [0 ].as <TupleNode>();
392403 CHECK (tuple);
404+ CHECK (ref_tuple);
393405 const Array<Expr>& arr = tuple->fields ;
406+ const Array<Expr>& ref_arr = ref_tuple->fields ;
394407
395408 if (arr[0 ].as <QRealizeIntExprNode>()) {
396409 DataType dtype;
397410 Expr dom_scale;
398- Array<Expr> ret_args = UnifyDTypeScale (arr, &dtype, &dom_scale);
411+ Array<Expr> ret_args = UnifyDTypeScale (ref_arr, arr, &dtype, &dom_scale);
399412 Expr ret = ForwardOp (ref_call, {TupleNode::make (ret_args)});
400413 return QRealizeIntExprNode::make (ret, dom_scale, dtype);
401414 } else {
0 commit comments