Skip to content

Commit 4996286

Browse files
committed
[Relay][Quantization] Speed-aware quantization scheme improvement
1 parent c8373ec commit 4996286

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

python/tvm/relay/build_module.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .. import nd as _nd, target as _target, autotvm
1010
from ..contrib import graph_runtime as _graph_rt
1111
from . import ir_pass
12-
from . import expr
12+
from . import expr as _expr
1313
from .backend import interpreter as _interpreter
1414
from .backend import graph_runtime_codegen as _graph_gen
1515

@@ -22,6 +22,7 @@
2222
"FoldScaleAxis": 3,
2323
"AlterOpLayout": 3,
2424
"CanonicalizeOps": 3,
25+
"EliminateCommonSubexpr": 3,
2526
}
2627

2728

@@ -126,8 +127,8 @@ def _bind_params_by_name(func, params):
126127
arg = name_dict[k]
127128
if arg is None:
128129
raise ValueError("Multiple args in the function have name %s" % k)
129-
bind_dict[arg] = expr.const(v)
130-
return expr.bind(func, bind_dict)
130+
bind_dict[arg] = _expr.const(v)
131+
return _expr.bind(func, bind_dict)
131132

132133

133134
def optimize(func, target=None, params=None):
@@ -162,6 +163,16 @@ def optimize(func, target=None, params=None):
162163
func = ir_pass.infer_type(func)
163164
func = ir_pass.simplify_inference(func)
164165

166+
if cfg.pass_enabled("EliminateCommonSubexpr"):
167+
def fskip(expr):
168+
if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
169+
expr.attrs.dtype == 'int32':
170+
return True
171+
return False
172+
173+
func = ir_pass.infer_type(func)
174+
func = ir_pass.eliminate_common_subexpr(func, fskip)
175+
165176
if cfg.pass_enabled("CombineParallelConv2D"):
166177
func = ir_pass.infer_type(func)
167178
func = ir_pass.combine_parallel_conv2d(func)

python/tvm/relay/quantize/_annotate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def add_rewrite(ref_call, new_args, ctx):
192192
else:
193193
# quantize rhs to INPUT field if it is not Constant
194194
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
195+
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
196+
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
195197

196198
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
197199
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

src/relay/pass/quantize.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
124124
}
125125
return e;
126126
};
127-
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
127+
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
128128
});
129129

130130

@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
329329

330330

331331
/* \brief Unify the dom scale of arguments */
332-
Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
332+
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
333+
const Array<Expr>& args,
333334
DataType* dtype_ptr,
334335
Expr* scale_ptr) {
336+
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
335337
const QConfig& cfg = QConfig::Current();
336338

337339
std::vector<const QRealizeIntExprNode*> nptrs;
@@ -344,10 +346,17 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
344346
}
345347

346348
// unify the data type
349+
CHECK_EQ(ref_args.size(), args.size());
347350
DataType dtype = cfg->dtype_activation;
348351
for (size_t i = 0; i < ret.size(); ++i) {
352+
auto ref_arg = ref_args[i].as<CallNode>();
349353
if (nptrs[i]->dtype != dtype) {
350354
ret.Set(i, Cast(ret[i], dtype));
355+
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
356+
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
357+
auto new_arg = Cast(ret[i], cfg->dtype_input);
358+
new_arg = StopFusion(new_arg);
359+
ret.Set(i, Cast(new_arg, dtype));
351360
}
352361
}
353362

@@ -371,7 +380,7 @@ Expr AddRealize(const Call& ref_call,
371380
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
372381
DataType dtype;
373382
Expr dom_scale;
374-
Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
383+
Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
375384
Expr ret = ForwardOp(ref_call, ret_args);
376385
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
377386
}
@@ -387,15 +396,19 @@ Expr ConcatenateRealize(const Call& ref_call,
387396
const Array<Expr>& new_args,
388397
const NodeRef& ctx) {
389398
CHECK_EQ(new_args.size(), 1);
399+
CHECK_EQ(ref_call->args.size(), 1);
390400

391401
const auto* tuple = new_args[0].as<TupleNode>();
402+
const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
392403
CHECK(tuple);
404+
CHECK(ref_tuple);
393405
const Array<Expr>& arr = tuple->fields;
406+
const Array<Expr>& ref_arr = ref_tuple->fields;
394407

395408
if (arr[0].as<QRealizeIntExprNode>()) {
396409
DataType dtype;
397410
Expr dom_scale;
398-
Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
411+
Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
399412
Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
400413
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
401414
} else {

0 commit comments

Comments (0)