@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
124124 }
125125 return e;
126126 };
127- return ForwardRewrite (expr, " FQAnnotateRewrite" , nullptr , nullptr );
127+ return ForwardRewrite (expr, " FQAnnotateRewrite" , nullptr , fmulti_ref );
128128});
129129
130130
@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
329329
330330
331331/* \brief Unify the dom scale of arguments */
332- Array<Expr> UnifyDTypeScale (const Array<Expr>& args,
332+ Array<Expr> UnifyDTypeScale (const Array<Expr>& ref_args,
333+ const Array<Expr>& args,
333334 DataType* dtype_ptr,
334335 Expr* scale_ptr) {
336+ static const Op& simulated_quantize = Op::Get (" relay.op.annotation.simulated_quantize" );
335337 const QConfig& cfg = QConfig::Current ();
336338
337339 std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,19 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
344346 }
345347
346348 // unify the data type
349+ CHECK_EQ (ref_args.size (), args.size ());
347350 DataType dtype = cfg->dtype_activation ;
348351 for (size_t i = 0 ; i < ret.size (); ++i) {
352+ auto ref_arg = ref_args[i].as <CallNode>();
349353 if (nptrs[i]->dtype != dtype) {
350354 ret.Set (i, Cast (ret[i], dtype));
355+ } else if (ref_arg && ref_arg->op .same_as (simulated_quantize) &&
356+ ref_arg->attrs .as <SimulatedQuantizeAttrs>()->kind == kQInput ) {
357+ auto new_arg = Cast (ret[i], cfg->dtype_input );
358+ if (cfg->use_stop_fusion ) {
359+ new_arg = StopFusion (new_arg);
360+ }
361+ ret.Set (i, Cast (new_arg, dtype));
351362 }
352363 }
353364
@@ -371,7 +382,7 @@ Expr AddRealize(const Call& ref_call,
371382 if (new_args[0 ].as <QRealizeIntExprNode>() && new_args[1 ].as <QRealizeIntExprNode>()) {
372383 DataType dtype;
373384 Expr dom_scale;
374- Array<Expr> ret_args = UnifyDTypeScale (new_args, &dtype, &dom_scale);
385+ Array<Expr> ret_args = UnifyDTypeScale (ref_call-> args , new_args, &dtype, &dom_scale);
375386 Expr ret = ForwardOp (ref_call, ret_args);
376387 return QRealizeIntExprNode::make (ret, dom_scale, dtype);
377388 }
@@ -387,15 +398,19 @@ Expr ConcatenateRealize(const Call& ref_call,
387398 const Array<Expr>& new_args,
388399 const NodeRef& ctx) {
389400 CHECK_EQ (new_args.size (), 1 );
401+ CHECK_EQ (ref_call->args .size (), 1 );
390402
391403 const auto * tuple = new_args[0 ].as <TupleNode>();
404+ const auto * ref_tuple = ref_call->args [0 ].as <TupleNode>();
392405 CHECK (tuple);
406+ CHECK (ref_tuple);
393407 const Array<Expr>& arr = tuple->fields ;
408+ const Array<Expr>& ref_arr = ref_tuple->fields ;
394409
395410 if (arr[0 ].as <QRealizeIntExprNode>()) {
396411 DataType dtype;
397412 Expr dom_scale;
398- Array<Expr> ret_args = UnifyDTypeScale (arr, &dtype, &dom_scale);
413+ Array<Expr> ret_args = UnifyDTypeScale (ref_arr, arr, &dtype, &dom_scale);
399414 Expr ret = ForwardOp (ref_call, {TupleNode::make (ret_args)});
400415 return QRealizeIntExprNode::make (ret, dom_scale, dtype);
401416 } else {
@@ -530,7 +545,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
530545 p->stream << " skip_k_conv==" << op->skip_k_conv << " , " ;
531546 p->stream << " round_for_shift==" << op->round_for_shift << " , " ;
532547 p->stream << " store_lowbit_output==" << op->store_lowbit_output << " , " ;
533- p->stream << " debug_enabled_ops==" << op->debug_enabled_ops ;
548+ p->stream << " debug_enabled_ops==" << op->debug_enabled_ops << " , " ;
549+ p->stream << " use_stop_fusion==" << op->use_stop_fusion ;
534550 p->stream << " )" ;
535551});
536552
0 commit comments