Commit 9403ba6

Adding the fixed point compute handling for requantization.
1 parent f365ea7 commit 9403ba6

File tree

7 files changed: +858, -135 lines


include/tvm/relay/attrs/nn_quantize.h

Lines changed: 5 additions & 2 deletions
@@ -52,7 +52,8 @@ struct QuantizedConv2DAttrs : public tvm::AttrsNode<QuantizedConv2DAttrs> {
   double input_scale;
   double kernel_scale;
   double output_scale;
-  bool use_integer_computation_for_scale_handling;
+  bool use_int_compute_for_requantize;
+  std::string rounding;

   TVM_DECLARE_ATTRS(QuantizedConv2DAttrs, "relay.attrs.QuantizedConv2DAttrs") {
     TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
@@ -107,8 +108,10 @@ struct QuantizedConv2DAttrs : public tvm::AttrsNode<QuantizedConv2DAttrs> {
         .describe("The scale of the kernel tensor.");
     TVM_ATTR_FIELD(output_scale)
         .describe("The scale of the output tensor.");
-    TVM_ATTR_FIELD(use_integer_computation_for_scale_handling).set_default(false)
+    TVM_ATTR_FIELD(use_int_compute_for_requantize).set_default(false)
         .describe("When true, the integer computation is used to handle output scale");
+    TVM_ATTR_FIELD(rounding).set_default("ceil")
+        .describe("The rounding that has to be used for handling scales.");


  }
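Background on the two new attributes (an editor's sketch, not code from this commit): requantization rescales the int32 accumulator by input_scale * kernel_scale / output_scale, and when use_int_compute_for_requantize is true that float ratio has to be replaced by integer arithmetic. The standard decomposition folds the scale into a Q31 multiplier plus a shift. Helper names below are hypothetical, and the sketch assumes 0 < scale < 1, the usual requantize case:

#include <cmath>
#include <cstdint>

// Decompose scale so that x * scale ~= (x * multiplier) >> (31 - shift).
void ScaleToFixedPoint(double scale, int32_t* multiplier, int* shift) {
  double mantissa = std::frexp(scale, shift);        // scale = mantissa * 2^shift, mantissa in [0.5, 1)
  int64_t q = std::llround(mantissa * (1LL << 31));  // mantissa as a Q31 integer
  if (q == (1LL << 31)) {                            // mantissa rounded up to exactly 1.0
    q /= 2;
    ++*shift;
  }
  *multiplier = static_cast<int32_t>(q);
}

// Apply the multiplier with round-half-up shown for illustration; the new
// `rounding` attribute is what selects this behaviour in the op (default "ceil").
int32_t FixedPointMultiply(int32_t x, int32_t multiplier, int shift) {
  int64_t prod = static_cast<int64_t>(x) * multiplier;  // fits in 62 bits
  int total_shift = 31 - shift;                         // >= 31 when scale < 1
  return static_cast<int32_t>((prod + (1LL << (total_shift - 1))) >> total_shift);
}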

include/tvm/relay/quantize_util.h

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relay/quantize_util.h
+ * \brief Utility methods needed by quantized ops that can be shared
+ */
+
+#ifndef TVM_QUANTIZE_UTIL_H
+#define TVM_QUANTIZE_UTIL_H
+
+#include <tvm/expr.h>
+#include "./base.h"
+
+namespace tvm {
+namespace relay {
+
+inline bool is_Int8(const DataType& dtype) {
+  return dtype == Int(8);
+}
+
+inline bool is_UInt8(const DataType& dtype) {
+  return dtype == UInt(8);
+}
+
+inline bool is_Int16(const DataType& dtype) {
+  return dtype == Int(16);
+}
+
+inline bool is_UInt16(const DataType& dtype) {
+  return dtype == UInt(16);
+}
+
+inline bool is_Int32(const DataType& dtype) {
+  return dtype == Int(32);
+}
+
+inline bool is_UInt32(const DataType& dtype) {
+  return dtype == UInt(32);
+}
+
+inline bool is_Float32(const DataType& dtype) {
+  return dtype == Float(32);
+}
+
+inline bool is_quantized_type(const DataType& dtype) {
+  return is_Int8(dtype) || is_UInt8(dtype)
+      || is_Int16(dtype) || is_UInt16(dtype);
+}
+
+enum class QuantizeOpType : uint8_t {
+  Quantize_Requantize,
+  Dequantize
+};
+
+inline bool is_valid_quantized_op_input_type(const QuantizeOpType& op_type, const DataType& in_dtype) {
+  switch (op_type) {
+    case QuantizeOpType::Quantize_Requantize:
+      return is_Float32(in_dtype) || is_quantized_type(in_dtype);
+    case QuantizeOpType::Dequantize:
+      return is_quantized_type(in_dtype);
+    default:
+      return false;
+  }
+}
+
+inline bool is_valid_quantized_op_output_type(const QuantizeOpType& op_type, const DataType& in_dtype) {
+  switch (op_type) {
+    case QuantizeOpType::Quantize_Requantize:
+      return is_quantized_type(in_dtype);
+    case QuantizeOpType::Dequantize:
+      return is_Float32(in_dtype);
+    default:
+      return false;
+  }
+}
+
+inline const int32_t get_qmin(const DataType& dtype) {
+  if (is_Int8(dtype)) {
+    return std::numeric_limits<int8_t>::min();
+  } else if (is_UInt8(dtype)) {
+    return std::numeric_limits<uint8_t>::min();
+  } else if (is_Int16(dtype)) {
+    return std::numeric_limits<int16_t>::min();
+  } else if (is_UInt16(dtype)) {
+    return std::numeric_limits<uint16_t>::min();
+  } else if (is_Int32(dtype)) {
+    return std::numeric_limits<int32_t>::min();
+  } else if (is_UInt32(dtype)) {
+    return std::numeric_limits<uint32_t>::min();
+  }
+  LOG(FATAL) << "Type not supported\n";
+  return -1;
+}
+
+inline const int32_t get_qmax(const DataType& dtype) {
+  if (is_Int8(dtype)) {
+    return std::numeric_limits<int8_t>::max();
+  } else if (is_UInt8(dtype)) {
+    return std::numeric_limits<uint8_t>::max();
+  } else if (is_Int16(dtype)) {
+    return std::numeric_limits<int16_t>::max();
+  } else if (is_UInt16(dtype)) {
+    return std::numeric_limits<uint16_t>::max();
+  } else if (is_Int32(dtype)) {
+    return std::numeric_limits<int32_t>::max();
+  } else if (is_UInt32(dtype)) {
+    return std::numeric_limits<uint32_t>::max();
+  }
+  LOG(FATAL) << "Type not supported\n";
+  return -1;
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_QUANTIZE_UTIL_H
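A minimal usage sketch for the range helpers above (hypothetical function, assuming this header is included): the last step of requantization clamps the rescaled accumulator into the output dtype's representable range before narrowing.

#include <algorithm>
#include <cstdint>

// Clamp an int32 value into [get_qmin(dtype), get_qmax(dtype)],
// e.g. [-128, 127] when the output dtype is int8.
inline int32_t ClampToOutputRange(int32_t value, const DataType& out_dtype) {
  const int32_t q_min = get_qmin(out_dtype);
  const int32_t q_max = get_qmax(out_dtype);
  return std::min(std::max(value, q_min), q_max);
}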

python/tvm/relay/op/nn/_quantize.py

Lines changed: 11 additions & 2 deletions
@@ -37,7 +37,9 @@ def quantized_conv2d(quantized_data,
                      data_layout="NCHW",
                      kernel_layout="OIHW",
                      out_layout="",
-                     out_dtype=""):
+                     out_dtype="",
+                     rounding="ceil",
+                     use_int_compute_for_requantize=False):
     r"""Quantized 2D convolution.

     This operator takes the quantized_weight as the convolution kernel
@@ -119,6 +121,12 @@ def quantized_conv2d(quantized_data,
     out_dtype : str, optional
         Specifies the output quantized_data type for mixed precision conv2d.

+    rounding : str, optional
+        Specifies which rounding to use - floor, ceil, round, trunc.
+
+    use_int_compute_for_requantize : bool, optional
+        Use fully integer computation for requantizing.
+
     Returns
     -------
     result : tvm.relay.Expr
@@ -130,4 +138,5 @@ def quantized_conv2d(quantized_data,
                            strides, padding, dilation,
                            groups, channels, kernel_size,
                            data_layout, kernel_layout, out_layout,
-                           out_dtype)
+                           out_dtype, rounding,
+                           use_int_compute_for_requantize)
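The four rounding options documented above differ only in how the scaled value is pushed back to an integer. A standalone illustration (hypothetical helper, written over a power-of-two divisor for brevity; it assumes an arithmetic right shift for negative values and s >= 1):

#include <cstdint>
#include <string>

// Divide v by 2^s under each of the documented rounding modes.
int64_t ShiftRightRounded(int64_t v, int s, const std::string& rounding) {
  if (rounding == "floor") return v >> s;                       // toward -infinity
  if (rounding == "ceil")  return (v + (1LL << s) - 1) >> s;    // toward +infinity
  if (rounding == "round") return (v + (1LL << (s - 1))) >> s;  // nearest, half-up
  return v < 0 ? -((-v) >> s) : (v >> s);                       // "trunc": toward zero
}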

src/relay/op/nn/quantized_convolution.cc

Lines changed: 5 additions & 1 deletion
@@ -146,7 +146,9 @@ Expr MakeQuantizeConv2D(Expr quantized_data,
                         std::string data_layout,
                         std::string kernel_layout,
                         std::string out_layout,
-                        DataType out_dtype) {
+                        DataType out_dtype,
+                        std::string rounding,
+                        bool use_int_compute_for_requantize) {
   auto attrs = make_node<QuantizedConv2DAttrs>();
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
@@ -164,6 +166,8 @@ Expr MakeQuantizeConv2D(Expr quantized_data,
   attrs->input_scale = std::move(input_scale);
   attrs->kernel_scale = std::move(kernel_scale);
   attrs->output_scale = std::move(output_scale);
+  attrs->rounding = std::move(rounding);
+  attrs->use_int_compute_for_requantize = std::move(use_int_compute_for_requantize);
   static const Op& op = Op::Get("nn_quantized.quantized_conv2d");
   return CallNode::make(op, {quantized_data, quantized_weight}, Attrs(attrs), {});
 }

src/relay/pass/pattern_util.h

Lines changed: 19 additions & 0 deletions
@@ -399,6 +399,25 @@ inline Expr Conv2D(Expr data,
   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
 }

+inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
+  static const Op& op = Op::Get("where");
+  return CallNode::make(op, {condition, x, y});
+}
+
+inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
+  static const Op& op = Op::Get("greater_equal");
+  return CallNode::make(op, {lhs, rhs}, Attrs(), {});
+}
+
+inline Expr Full(Expr fill_value,
+                 Array<IndexExpr> shape,
+                 DataType dtype) {
+  auto attrs = make_node<InitOpAttrs>();
+  attrs->shape = std::move(shape);
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("full");
+  return CallNode::make(op, {fill_value}, Attrs(attrs), {});
+}
 Expr MakeConcatenate(Expr data, int axis);

 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
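These three helpers give the pattern rewriter elementwise select machinery, which is plausibly what a sign-aware rounding lowering needs: build a zero tensor, compare, and choose per element. A hypothetical composition, not code from this commit (MakeConstantScalar is assumed to be the scalar-constant helper already in pattern_util.h):

// Keep `if_non_negative` where data >= 0 and `if_negative` elsewhere.
inline Expr SelectBySign(const Expr& data,
                         const Expr& if_non_negative,
                         const Expr& if_negative,
                         Array<IndexExpr> shape,
                         DataType dtype) {
  // Zero tensor with the data's shape and dtype, built with Full above.
  Expr zero = Full(MakeConstantScalar(dtype, 0), shape, dtype);
  // Elementwise mask, true where data is non-negative.
  Expr mask = GreaterEqual(data, zero);
  // where(mask, x, y) takes x at true positions and y elsewhere.
  return Where(mask, if_non_negative, if_negative);
}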
