Commit cd27778

[Relay][Quantize] Use fixed point multiplications

1 parent 3a32729 commit cd27778

File tree

7 files changed: +29 -19 lines changed

python/tvm/relay/quantize/quantize.py
4 additions & 0 deletions

@@ -83,6 +83,7 @@ class QConfig(NodeBase):
         "do_simulation": False,
         "round_for_shift": True,
         "debug_enabled_ops": None,
+        "rounding": "UPWARD"
     }
 
 # pylint: disable=no-member
@@ -160,6 +161,9 @@ def qconfig(**kwargs):
         is None, which means will try to call all operartors' annotate rewrite
         function.
 
+    rounding: "UPWARD" or "TONEAREST"
+        Rounding direction for fixed point multiplications.
+
     Returns
     -------
     config: QConfig
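For context, a rough usage sketch of the new option from Python. This assumes the relay.quantize API of this era; the one-conv model and its parameter values are made up for illustration:

import numpy as np
import tvm
from tvm import relay

# Made-up one-conv model; the weight is supplied via params so the
# quantizer can fold and quantize it.
data = relay.var("data", shape=(1, 3, 8, 8), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
func = relay.Function([data, weight],
                      relay.nn.conv2d(data, weight, kernel_size=(3, 3)))
params = {"weight": tvm.nd.array(np.ones((8, 3, 3, 3), "float32"))}

# Select the new rounding mode for all fixed point multiplies.
with relay.quantize.qconfig(rounding="TONEAREST"):
    qfunc = relay.quantize.quantize(func, params)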

src/relay/pass/quantize/quantize.cc
2 additions & 1 deletion

@@ -128,7 +128,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "do_simulation==" << op->do_simulation << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
+  p->stream << "rounding==" << op->rounding;
   p->stream << ")";
 });

src/relay/pass/quantize/quantize.h
2 additions & 0 deletions

@@ -77,6 +77,7 @@ class QConfigNode : public Node {
   bool do_simulation = false;
   bool round_for_shift = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+  std::string rounding = "UPWARD";
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("nbit_input", &nbit_input);
@@ -90,6 +91,7 @@ class QConfigNode : public Node {
     v->Visit("do_simulation", &do_simulation);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
+    v->Visit("rounding", &rounding);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";

src/relay/pass/quantize/realize.cc
13 additions & 10 deletions

@@ -31,6 +31,7 @@
 #include <tvm/relay/attrs/annotation.h>
 #include "./quantize.h"
 #include "../pattern_util.h"
+#include "../../qnn/util.h"
 
 namespace tvm {
 namespace relay {
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
 
 
 /* calculate `data * s1 / s2`, use shift if possible */
-inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
+inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
+                      const Array<IndexExpr>& data_shape) {
+  const QConfig& cfg = QConfig::Current();
   // here we assume the dtype of data is dtype activation
   if (s1 == s2) return data;
 
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = Cast(data, Float(32));
-    data = Multiply(data, MakeConstantScalar(Float(32), factor));
-    return Cast(Round(data), dtype);
+    data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
+    return Cast(data, dtype);
   }
 }
 
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
     data = Clip(data, clip_min_imm, clip_max_imm);
     return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
   } else {
-    // float computation
-    data = Cast(data, Float(32));
-    Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
-    Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
-    return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+    data = Cast(data, Int(64));
+    data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
+                                   ref_call->type_as<TensorTypeNode>()->shape,
+                                   cfg->rounding);
+    data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
+    return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
   }
 }
 
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
   Expr dom_scale = MakeConstantScalar(Float(32), s);
   for (size_t i = 0; i < ret.size(); ++i) {
     float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
-    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
   }
 
   *dtype_ptr = dtype;
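Worked example for the new else-branch above: scaling by a non-integer, non-power-of-two factor such as 0.6 becomes an integer multiply plus shift instead of a float32 round-trip. A self-contained sketch of the arithmetic (UPWARD rounding assumed; 1288490189 is round(0.6 * 2**31)):

x = 100
significand = 1288490189   # round(0.6 * 2**31), Q31 fixed point
shift = 31                 # 31 minus the frexp exponent of 0.6, which is 0
scaled = (x * significand + (1 << (shift - 1))) >> shift
print(scaled)              # -> 60, i.e. round(100 * 0.6)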

src/relay/qnn/op/requantize.cc
2 additions & 4 deletions

@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
 
 // Lowering of qnn.requantize op
 
-
-
 /*
  * \brief Lower requantize to a sequence of ops.
  * \param input_tensor The input tensor to requantize op.
@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
   // 2) If the input and output scales are same, we can skip the fixed point multiplication.
   auto scaled_int64_t = tensor;
   if (param->input_scale != param->output_scale) {
-    scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
-                                       param->rounding);
+    scaled_int64_t =
+        FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
   }
 
   // 3) Add the output zero point.
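As a rough reference for the numbered lowering steps in the comments, a NumPy sketch with made-up scales and zero points; plain float rounding stands in for the FixedPointMultiply step, and an int8 output range is assumed:

import numpy as np

def requantize_ref(q_in, in_scale, in_zp, out_scale, out_zp):
    # 1) Subtract the input zero point, working in int64 to avoid overflow.
    x = q_in.astype(np.int64) - in_zp
    # 2) Scale by in_scale / out_scale (lowered to FixedPointMultiply).
    if in_scale != out_scale:
        x = np.round(x * (in_scale / out_scale)).astype(np.int64)
    # 3) Add the output zero point, then clip to the output dtype range.
    return np.clip(x + out_zp, -128, 127).astype(np.int8)

print(requantize_ref(np.array([10, 20], dtype=np.int32), 0.5, 0, 0.25, 3))
# -> [23 43]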

src/relay/qnn/util.cc
3 additions & 1 deletion

@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
   return std::make_pair(significand, exponent);
 }
 
-Expr FixedPointMuliply(Expr tensor, double multiplier,
+Expr FixedPointMultiply(Expr tensor, double multiplier,
                        const Array<IndexExpr>& input_shape, const std::string& rounding) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
     auto zero_t = Zeros(input_shape, hp_dtype);
     round_scalar =
         Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  } else {
+    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
   }
   // Add the rounding scalar.
   tensor = Add(tensor, round_scalar);
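As a reading aid for the rounding logic, including the new failure path: a scalar Python model of the routine, sketching the algorithm as described in util.h rather than reproducing TVM's exact code:

import math

def fixed_point_multiply(x, multiplier, rounding="UPWARD"):
    # Decompose multiplier ~= significand * 2**(exponent - 31), where
    # significand is an int32 holding a Q31 fixed point fraction.
    frac, exponent = math.frexp(multiplier)
    significand = int(round(frac * (1 << 31)))
    if significand == (1 << 31):        # frac rounded up to exactly 1.0
        significand //= 2
        exponent += 1
    prod = x * significand              # done in int64 in the real pass
    total_shift = 31 - exponent
    pos_rounder = 1 << (total_shift - 1)
    if rounding == "UPWARD":
        rounder = pos_rounder           # ties go toward +infinity
    elif rounding == "TONEAREST":
        # Negative inputs get one less, so ties round away from zero.
        rounder = pos_rounder if x >= 0 else pos_rounder - 1
    else:
        raise ValueError("Rounding mode %s not supported." % rounding)
    return (prod + rounder) >> total_shift  # arithmetic right shift

print(fixed_point_multiply(100, 0.6))              # -> 60
print(fixed_point_multiply(-5, 0.5, "UPWARD"))     # -2.5 -> -2
print(fixed_point_multiply(-5, 0.5, "TONEAREST"))  # -2.5 -> -3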

src/relay/qnn/util.h
3 additions & 3 deletions

@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
  * 2) Round the result.
  * 3) Right shift the result
  */
-Expr FixedPointMuliply(Expr tensor, double multiplier,
-                       const Array<IndexExpr>& input_shape,
-                       const std::string& rounding);
+Expr FixedPointMultiply(Expr tensor, double multiplier,
+                        const Array<IndexExpr>& input_shape,
+                        const std::string& rounding);
 
 }  // namespace qnn
 }  // namespace relay
