Commit 18824b4

Add FP requantize flow for llvm target

1 parent 5557b8c commit 18824b4

5 files changed: +358, -12 lines

python/tvm/relay/qnn/op/qnn.py

Lines changed: 38 additions & 0 deletions
@@ -92,6 +92,44 @@ def requantize(
     )


+def upward(data):
+    r"""Upward operator.
+
+    UPWARD is the standard rounding, except at midpoints, where the value
+    is rounded toward positive infinity (for example, -1.5 rounds to -1).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.upward(data)
+
+
+def tonearest(data):
+    r"""Tonearest operator.
+
+    TONEAREST is the standard rounding where the value is rounded away
+    from zero at midpoints (for example, -1.5 rounds to -2).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.tonearest(data)
+
+
 def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
     r"""Quantize op
     This operator takes float32 as input and produces quantized int8 or unit8 as output.
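The two modes only disagree at negative midpoints. A standalone NumPy sketch of both conventions (illustrative only, not part of this commit):

    import numpy as np

    x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])

    # UPWARD: midpoints round toward positive infinity.
    upward = np.floor(x + 0.5)                          # [-2. -1.  0.  1.  2.  3.]

    # TONEAREST: midpoints round away from zero.
    tonearest = np.sign(x) * np.floor(np.abs(x) + 0.5)  # [-3. -2. -1.  1.  2.  3.]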

python/tvm/topi/x86/utils.py

Lines changed: 25 additions & 2 deletions
@@ -16,8 +16,26 @@
 # under the License.
 """Common x86 related utilities"""
 import tvm
+import tvm._ffi


+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41")
+def target_has_sse41(target):
+    return (
+        target_has_sse42(target)
+        or target_has_avx(target)
+        or target_has_avx2(target)
+        or target_has_avx512(target)
+        or target_has_vnni(target)
+        or target
+        in {
+            "btver2",
+            "penryn",
+        }
+    )
+
+
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42")
 def target_has_sse42(target):
     return (
         target_has_avx(target)
@@ -42,6 +60,7 @@ def target_has_sse42(target):
     )


+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx")
 def target_has_avx(target):
     return (
         target_has_avx2(target)
@@ -51,6 +70,7 @@ def target_has_avx(target):
     )


+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2")
 def target_has_avx2(target):
     return (
         target_has_avx512(target)
@@ -70,6 +90,7 @@ def target_has_avx2(target):
     )


+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512")
 def target_has_avx512(target):
     return target in {
         "skylake-avx512",
@@ -82,26 +103,28 @@ def target_has_avx512(target):
         "cascadelake",
         "icelake-client",
         "rocketlake",
-        "icelake",
+        "icelake-server",
         "tigerlake",
         "cooperlake",
         "sapphirerapids",
     }


+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni")
 def target_has_vnni(target):
     return target in {
         "cascadelake",
         "icelake-client",
         "rocketlake",
-        "icelake",
+        "icelake-server",
         "tigerlake",
         "cooperlake",
         "sapphirerapids",
         "alderlake",
     }


+@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
 def get_simd_32bit_lanes():
     mcpu = tvm.target.Target.current().mcpu
     fp32_vec_len = 4
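Registering each helper as a global packed function lets the C++ lowering query CPU features without importing the Python module. A quick sketch of the same lookup from the Python side (mirroring what requantize.cc does below):

    import tvm

    # The lookup requantize.cc performs via tvm::runtime::Registry::Get.
    has_sse41 = tvm.get_global_func("tvm.topi.x86.utils.target_has_sse41")
    print(has_sse41("skylake-avx512"))  # True: implied by AVX-512 support
    print(has_sse41("core2"))           # False: pre-SSE4.1 CPU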

src/relay/qnn/op/requantize.cc

Lines changed: 228 additions & 5 deletions
@@ -26,6 +26,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/qnn/attrs.h>

+#include "../../op/op_common.h"
 #include "../../transforms/infer_layout_utils.h"
 #include "../../transforms/pattern_utils.h"
 #include "../utils.h"
@@ -111,6 +112,107 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
   return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
 }

+bool has_current_target_sse41_support() {
+  auto target = Target::Current(true);
+  Optional<String> mcpu =
+      target.defined() ? target->GetAttr<String>("mcpu") : Optional<String>(nullptr);
+  auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41");
+  ICHECK(target_has_sse41_fn_ptr) << "Function tvm.topi.x86.utils.target_has_sse41 not found";
+  return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value());
+}
+
+/*
+ * \brief TONEAREST is the standard rounding where the value is rounded away
+ * from zero at midpoints (for example, -1.5 rounds to -2).
+ * \param input_tensor The input tensor to the rounding op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr Tonearest(const Expr& input_tensor) {
+  if (has_current_target_sse41_support()) return Round(input_tensor);
+
+  auto half = MakeConstantScalar(DataType::Float(64), 0.5f);
+  auto zero = MakeConstantScalar(DataType::Float(64), 0.f);
+  auto pos_one = MakeConstantScalar(DataType::Float(64), +1.f);
+  auto neg_one = MakeConstantScalar(DataType::Float(64), -1.f);
+  auto multiplier = Where(Less(input_tensor, zero), neg_one, pos_one);
+  auto half_multiplied = Multiply(half, multiplier);
+  auto input_tensor_biased = Add(input_tensor, half_multiplied);
+  auto input_tensor_biased_multiplied = Multiply(input_tensor_biased, multiplier);
+  auto input_tensor_biased_multiplied_int64 =
+      Cast(input_tensor_biased_multiplied, DataType::Int(64));
+  auto input_tensor_biased_multiplied_float64 =
+      Cast(input_tensor_biased_multiplied_int64, DataType::Float(64));
+  auto input_tensor_rounded = Multiply(input_tensor_biased_multiplied_float64, multiplier);
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
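The fallback works because a float-to-int cast truncates toward zero: reflecting negative inputs to the positive side first makes truncation of |x| + 0.5 equivalent to rounding half away from zero. A NumPy model of the same op sequence (illustrative only):

    import numpy as np

    def tonearest_fallback(x):
        m = np.where(x < 0, -1.0, 1.0)  # the Where(Less(...)) multiplier
        biased = (x + 0.5 * m) * m      # equals |x| + 0.5, always non-negative
        truncated = biased.astype(np.int64).astype(np.float64)  # Cast to Int(64) truncates
        return truncated * m            # reflect back to the original sign

    print(tonearest_fallback(np.array([-1.5, -0.3, 0.3, 1.5])))  # [-2. -0.  0.  2.]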
+/*
+ * \brief UPWARD is the standard rounding, except at midpoints, where the value
+ * is rounded toward positive infinity (for example, -1.5 rounds to -1).
+ * \param input_tensor The input tensor to the rounding op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr Upward(const Expr& input_tensor) {
+  auto half = MakeConstantScalar(DataType::Float(64), 0.5f);
+  auto input_tensor_biased = Add(input_tensor, half);
+  if (has_current_target_sse41_support()) return Floor(input_tensor_biased);
+
+  auto zero = MakeConstantScalar(DataType::Float(64), 0.f);
+  auto one = MakeConstantScalar(DataType::Float(64), +1.f);
+  auto input_tensor_biased_int64 = Cast(input_tensor_biased, DataType::Int(64));
+  auto input_tensor_biased_float64 = Cast(input_tensor_biased_int64, DataType::Float(64));
+  auto is_subtraction_not_necessary =
+      LogicalOr(Equal(input_tensor_biased, input_tensor_biased_float64),
+                GreaterEqual(input_tensor_biased, zero));
+  auto input_tensor_rounded = Where(is_subtraction_not_necessary, input_tensor_biased_float64,
+                                    Subtract(input_tensor_biased_float64, one));
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
+
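Without SSE4.1 the trailing Floor is emulated the same way: a truncating cast matches floor for non-negative or already-integral values and is one too high otherwise, hence the conditional subtract. In NumPy terms (again just a model):

    import numpy as np

    def upward_fallback(x):
        biased = x + 0.5
        truncated = biased.astype(np.int64).astype(np.float64)  # truncates toward zero
        # Truncation differs from floor only for negative, non-integral values.
        fixup_needed = (biased != truncated) & (biased < 0)
        return np.where(fixup_needed, truncated - 1.0, truncated)

    print(upward_fallback(np.array([-1.5, -0.7, 0.5, 1.2])))  # [-1. -1.  1.  1.]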
+// Positional relay function to create tonearest operator
+// used by frontend FFI.
+Expr MakeTonearest(const Attrs& attrs, const Array<Expr>& new_args,
+                   const Array<tvm::relay::Type>& types) {
+  ICHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  return Tonearest(data);
+}
+
+RELAY_REGISTER_OP("tonearest")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeTonearest);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.tonearest").set_body_typed([](Expr data) {
+  static const Op& op = Op::Get("tonearest");
+  return Call(op, {data}, Attrs(), {});
+});
+
+// Positional relay function to create upward operator
+// used by frontend FFI.
+Expr MakeUpward(const Attrs& attrs, const Array<Expr>& new_args,
+                const Array<tvm::relay::Type>& types) {
+  ICHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  return Upward(data);
+}
+
+RELAY_REGISTER_OP("upward")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeUpward);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.upward").set_body_typed([](Expr data) {
+  static const Op& op = Op::Get("upward");
+  return Call(op, {data}, Attrs(), {});
+});
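These globals back the `_make.upward` / `_make.tonearest` calls in the Python wrappers above. A minimal usage sketch (assuming the wrappers are exported alongside the other qnn ops):

    from tvm import relay
    from tvm.relay.qnn.op import qnn as qnn_op

    x = relay.var("x", shape=(4,), dtype="float64")
    rounded = qnn_op.upward(x)  # canonicalized to the Relay sequence built by Upward()
    print(relay.Function([x], rounded))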
 // Lowering of qnn.requantize op

 /*
@@ -119,7 +221,7 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * \param param The requantize op attrs.
  * \param input_shape The input tensor shape of the requantize op.
  * \return The sequence of existing Relay ops.
- * \note Requantization using only integer computation. Here, the computation is
+ * \note RequantizationInt using only integer computation. Here, the computation is
  *       converted to a fixed point computation by computing output multiplier
  *       and shift. This is useful, if the target device does not support/have
  *       very expensive floating point computations.
@@ -131,10 +233,10 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * 4) Add the output zero point.
  * 5) Cast to the out_dtype.
  */
-Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point, const Expr& output_scale,
-                     const Expr& output_zero_point, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
+                        const Expr& input_zero_point, const Expr& output_scale,
+                        const Expr& output_zero_point, const RequantizeAttrs* param,
+                        const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   auto tensor = Cast(input_tensor, DataType::Int(32));
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
@@ -208,6 +310,127 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
   return Cast(clipped_t, out_dtype);
 }

+// Lowering of qnn.requantize op
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note RequantizationFP uses floating point computation. All multiplication,
+ *       subtraction and addition happens in the floating point data type, and
+ *       only at the end is the result converted to the int32 data type and
+ *       clamped for the output data type.
+ *
+ *       The whole computation can be broken down into the following steps:
+ *       1) Subtract the input zero point.
+ *       2) Perform multiplication.
+ *       3) Add the output zero point.
+ *       4) Cast to the out_dtype.
+ */
+Expr RequantizeLowerFP(const Expr& input_tensor, const Expr& input_scale,
+                       const Expr& input_zero_point, const Expr& output_scale,
+                       const Expr& output_zero_point, const RequantizeAttrs* param,
+                       const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  auto tensor = Cast(input_tensor, DataType::Float(64));
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
+    // 1) Subtract the input zero point, broadcasting it if needed.
+    int rank = static_cast<int>(input_shape.size());
+    int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis;
+    Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point,
+                                                              {
+                                                                  -1,
+                                                              }),
+                                                      rank, {axis});
+    tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Float(64)));
+  }
+
+  // 2) If the input and output scales are the same, the multiplication can be skipped.
+  // Check whether the input scale is per-tensor or per-channel. For per-tensor, there is a
+  // single scale for the whole tensor. For per-channel (aka per-axis), there is a vector of
+  // scales for the input tensor. Depending on the quantization type, the corresponding
+  // multiplication routine is called.
+  auto scaled_fp64_t = tensor;
+  double output_scale_float = GetScalarFromConstant<float>(output_scale);
+  if (IsConstScalar(input_scale)) {
+    // This is per-tensor quantization. Single scale.
+    double input_scale_float = GetScalarFromConstant<float>(input_scale);
+    double double_multiplier = input_scale_float / output_scale_float;
+    // Skip the multiplication if the input and output scales are the same.
+    if (!IsEqualScalar(input_scale, output_scale)) {
+      double multiplier = double_multiplier;
+      auto m_scalar = MakeConstantScalar(DataType::Float(64), multiplier);
+      scaled_fp64_t = Multiply(m_scalar, scaled_fp64_t);
+    }
+
+  } else {
+    // This is per-channel (per-axis) quantization.
+    std::vector<double> double_multipliers;
+    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
+    double output_scale_float = GetScalarFromConstant<float>(output_scale);
+    for (auto input_axis_scale : input_axis_scales) {
+      double multiplier = static_cast<double>(input_axis_scale) / output_scale_float;
+      double_multipliers.push_back(multiplier);
+    }
+    int axis = param->axis;
+    axis = (axis == -1) ? input_shape.size() - 1 : axis;
+
+    auto fixed_pt_multiplier_expr = MakeConstantTensor(
+        DataType::Float(64), {(int64_t)double_multipliers.size()}, double_multipliers);
+    size_t n_dim = input_shape.size();
+    auto exp_fixed_pt_multiplier_expr =
+        ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {axis});
+
+    scaled_fp64_t = Multiply(scaled_fp64_t, exp_fixed_pt_multiplier_expr);
+  }
+
+  // 3) Add the output zero point.
+  auto shifted_fp64_t = scaled_fp64_t;
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    shifted_fp64_t = Add(shifted_fp64_t, Cast(output_zero_point, DataType::Float(64)));
+  }
+
+  // Round according to the requested rounding mode.
+  if (param->rounding == "UPWARD") {
+    shifted_fp64_t = Upward(shifted_fp64_t);
+  } else /*if (param->rounding == "TONEAREST")*/ {
+    shifted_fp64_t = Tonearest(shifted_fp64_t);
+  }
+
+  shifted_fp64_t = Cast(shifted_fp64_t, DataType::Int(32));
+  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32; the rounded
+  // value already fits in the int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_fp64_t;
+  }
+
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
+  auto clipped_t = Clip(shifted_fp64_t, q_min, q_max);
+  return Cast(clipped_t, out_dtype);
+}
+
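Numerically, the FP flow is just a few elementwise float64 ops. A compact NumPy model of the per-tensor case with UPWARD rounding and int8 output (illustrative, not the commit's code):

    import numpy as np

    def requantize_fp(x, in_scale, in_zp, out_scale, out_zp):
        t = x.astype(np.float64) - in_zp              # 1) subtract input zero point
        t = t * (in_scale / out_scale)                # 2) rescale by the scale ratio
        t = t + out_zp                                # 3) add output zero point
        t = np.floor(t + 0.5)                         # UPWARD rounding
        return np.clip(t, -128, 127).astype(np.int8)  # 4) clip and cast to int8

    x = np.array([0, 10, 20, 127], dtype=np.int32)
    print(requantize_fp(x, 0.25, 0, 0.5, 1))  # [ 1  6 11 65]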
+// Lowering of qnn.requantize op
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note Dispatches to the floating point flow for llvm targets and to the
+ *       integer-only flow for all other targets.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point, const Expr& output_scale,
+                     const Expr& output_zero_point, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  auto target = Target::Current(true);
+  if (target.defined() && target->kind->name == "llvm") {
+    return RequantizeLowerFP(input_tensor, input_scale, input_zero_point, output_scale,
+                             output_zero_point, param, input_shape, out_dtype);
+  } else {
+    return RequantizeLowerInt(input_tensor, input_scale, input_zero_point, output_scale,
+                              output_zero_point, param, input_shape, out_dtype);
+  }
+}
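With the dispatch in place, any qnn.requantize compiled for an llvm target takes the FP path automatically. A minimal end-to-end sketch using the public requantize API:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2), dtype="int32")
    y = relay.qnn.op.requantize(
        x,
        input_scale=relay.const(0.25, "float32"),
        input_zero_point=relay.const(0, "int32"),
        output_scale=relay.const(0.5, "float32"),
        output_zero_point=relay.const(1, "int32"),
        out_dtype="int8",
    )
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    # Target "llvm" selects RequantizeLowerFP; other targets keep RequantizeLowerInt.
    lib = relay.build(mod, target="llvm")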
 /*
  * \brief Forward rewrite the requantize op.
  * \param ref_call The original call that will be lowered.