
Commit 4b29df7

vinx13 authored and wweic committed
[Relay][Quantization] Speed-aware quantization scheme improvement (apache#2723)
* [Relay][Quantization] Speed-aware quantization scheme improvement
* Add comment
* Add use_stop_fusion to qconfig
* Update comment
1 parent 2470364 commit 4b29df7

File tree

5 files changed: 45 additions & 8 deletions


python/tvm/relay/build_module.py

Lines changed: 14 additions & 3 deletions
@@ -9,7 +9,7 @@
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
 from . import ir_pass
-from . import expr
+from . import expr as _expr
 from .backend import interpreter as _interpreter
 from .backend import graph_runtime_codegen as _graph_gen

@@ -22,6 +22,7 @@
     "FoldScaleAxis": 3,
     "AlterOpLayout": 3,
     "CanonicalizeOps": 3,
+    "EliminateCommonSubexpr": 3,
 }


@@ -126,8 +127,8 @@ def _bind_params_by_name(func, params):
         arg = name_dict[k]
         if arg is None:
             raise ValueError("Multiple args in the function have name %s" % k)
-        bind_dict[arg] = expr.const(v)
-    return expr.bind(func, bind_dict)
+        bind_dict[arg] = _expr.const(v)
+    return _expr.bind(func, bind_dict)


 def optimize(func, target=None, params=None):
@@ -162,6 +163,16 @@ def optimize(func, target=None, params=None):
     func = ir_pass.infer_type(func)
     func = ir_pass.simplify_inference(func)

+    if cfg.pass_enabled("EliminateCommonSubexpr"):
+        def fskip(expr):
+            if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
+                expr.attrs.dtype == 'int32':
+                return True
+            return False
+
+        func = ir_pass.infer_type(func)
+        func = ir_pass.eliminate_common_subexpr(func, fskip)
+
     if cfg.pass_enabled("CombineParallelConv2D"):
         func = ir_pass.infer_type(func)
         func = ir_pass.combine_parallel_conv2d(func)
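
The new "EliminateCommonSubexpr" entry runs common-subexpression elimination inside optimize(), but hands the pass a skip callback so that cast calls targeting int32 are never deduplicated. A minimal sketch of driving the same pass directly, assuming a TVM build that contains this commit; the toy graph with two identical exp calls is purely illustrative:

    from tvm import relay
    from tvm.relay import ir_pass
    from tvm.relay import expr as _expr

    def fskip(e):
        # Same predicate as the one added to optimize(): leave cast-to-int32 alone.
        if isinstance(e, _expr.Call) and e.op.name == 'cast' and \
                e.attrs.dtype == 'int32':
            return True
        return False

    x = relay.var("x", shape=(1, 16), dtype="float32")
    y = relay.add(relay.exp(x), relay.exp(x))   # duplicated subexpression
    func = relay.Function([x], y)
    func = ir_pass.infer_type(func)
    func = ir_pass.eliminate_common_subexpr(func, fskip)   # the two exp calls are merged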

python/tvm/relay/quantize/_annotate.py

Lines changed: 3 additions & 0 deletions
@@ -192,6 +192,9 @@ def add_rewrite(ref_call, new_args, ctx):
         else:
             # quantize rhs to INPUT field if it is not Constant
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
+    if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
+        # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
+        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)

     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
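
The extra branch in add_rewrite covers adds whose operands are both already annotated as ACTIVATION, e.g. the skip connection of a residual block: as the comment above says, the right-hand side then gets one more simulated_quantize of kind INPUT. A hypothetical Relay graph that reaches this case when run through relay.quantize.quantize; names and shapes are illustrative only:

    from tvm import relay

    data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
    w1 = relay.var("w1", shape=(16, 16, 3, 3), dtype="float32")
    w2 = relay.var("w2", shape=(16, 16, 3, 3), dtype="float32")

    # Two conv outputs feed the same add, so during annotation both operands
    # of the add are ACTIVATION-kind expressions.
    y1 = relay.nn.conv2d(data, w1, padding=(1, 1), kernel_size=(3, 3), channels=16)
    y2 = relay.nn.conv2d(relay.nn.relu(y1), w2, padding=(1, 1), kernel_size=(3, 3), channels=16)
    out = relay.add(y1, y2)
    func = relay.Function(relay.ir_pass.free_vars(out), out)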

python/tvm/relay/quantize/quantize.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,7 @@ class QConfig(NodeBase):
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
+        "use_stop_fusion": True
     }

     # pylint: disable=no-member
@@ -129,6 +130,10 @@ def qconfig(**kwargs):
         Whether to store low-bit integer back as output before dequantizing.
         Some accelerators need this, e.g. VTA.

+    use_stop_fusion: boolean
+        Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit
+        results to be stored in memory.
+
     Returns
     -------
     config: QConfig
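
use_stop_fusion is exposed as an ordinary qconfig option. A hedged usage sketch, assuming a TVM build with this commit; the other option shown simply restates its documented default:

    from tvm import relay

    def quantize_with_stop_fusion(func, params):
        # With use_stop_fusion=True (the default), the realize pass may wrap the
        # low-bit cast in stop_fusion so the low-bit tensor is actually stored.
        with relay.quantize.qconfig(store_lowbit_output=True,
                                    use_stop_fusion=True):
            return relay.quantize.quantize(func, params=params)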

src/relay/pass/quantize.cc

Lines changed: 21 additions & 5 deletions
@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
     }
     return e;
   };
-  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
+  return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
 });


@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {


 /* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
+                            const Array<Expr>& args,
                             DataType* dtype_ptr,
                             Expr* scale_ptr) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();

   std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,19 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
   }

   // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
   DataType dtype = cfg->dtype_activation;
   for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as<CallNode>();
     if (nptrs[i]->dtype != dtype) {
       ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      if (cfg->use_stop_fusion) {
+        new_arg = StopFusion(new_arg);
+      }
+      ret.Set(i, Cast(new_arg, dtype));
     }
   }

@@ -371,7 +382,7 @@ Expr AddRealize(const Call& ref_call,
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, ret_args);
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   }
@@ -387,15 +398,19 @@ Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
                         const NodeRef& ctx) {
   CHECK_EQ(new_args.size(), 1);
+  CHECK_EQ(ref_call->args.size(), 1);

   const auto* tuple = new_args[0].as<TupleNode>();
+  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
   CHECK(tuple);
+  CHECK(ref_tuple);
   const Array<Expr>& arr = tuple->fields;
+  const Array<Expr>& ref_arr = ref_tuple->fields;

   if (arr[0].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
     Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
     return QRealizeIntExprNode::make(ret, dom_scale, dtype);
   } else {
@@ -530,7 +545,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
+  p->stream << "use_stop_fusion==" << op->use_stop_fusion;
   p->stream << ")";
 });
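
The new branch in UnifyDTypeScale handles an argument that is already in dtype_activation but whose reference expression was a simulated_quantize of kind kQInput: it is cast down to dtype_input, optionally wrapped in StopFusion, and cast back up, so the value that actually gets stored is the low-bit one. A plain-Python sketch of that per-argument decision; cast and stop_fusion here are hypothetical stand-ins for the C++ Cast and StopFusion helpers:

    def unify_arg(arg, arg_dtype, ref_is_input_quantize,
                  dtype_activation="int32", dtype_input="int8",
                  use_stop_fusion=True,
                  cast=lambda e, t: ("cast", e, t),
                  stop_fusion=lambda e: ("stop_fusion", e)):
        """Illustrative stand-in for the per-argument logic in UnifyDTypeScale."""
        if arg_dtype != dtype_activation:
            # not yet in the activation dtype: a single cast is enough
            return cast(arg, dtype_activation)
        if ref_is_input_quantize:
            # already in dtype_activation, but the reference arg was
            # simulated_quantize(kind=kQInput): round-trip through the input
            # dtype so the stored tensor is low-bit
            low = cast(arg, dtype_input)
            if use_stop_fusion:
                low = stop_fusion(low)   # keep the low-bit value from being fused away
            return cast(low, dtype_activation)
        return arg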

src/relay/pass/quantize.h

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,7 @@ class QConfigNode : public Node {
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+  bool use_stop_fusion = true;

   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("nbit_input", &nbit_input);
@@ -123,6 +124,7 @@ class QConfigNode : public Node {
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
+    v->Visit("use_stop_fusion", &use_stop_fusion);
   }

   static constexpr const char* _type_key = "relay.quantize.QConfig";
