Skip to content

Commit cff0bdd

Browse files
committed
[Relay] [Quantization] WIP - Prototyping the quantized convolution op
Goal - Act as medium of discussion for pull request #2351 Features - New quantized conv2D op in Relay - Python API interface to instantiate the Relay op - Infer Type implemented - Lowering of quantized_conv op to low-level Relay ops Discussion points - Does the namespace look correct? - Relay op is called 'relay.op.nn._quantize.quantized_conv2d' - Idea is that any op under the '_quantize' namespace will go through rewrite. - Should we reuse Conv2DRel and Conv2DAttrs - Tried prototyping. Found it hard to derive from the Conv2DAttrs struct - Infer Type has a param field. This needs to come from the right datatype. Missing implementation - Lowering of quantized conv into conv+cast is incomplete. - Will work on it async. This is orthogonal to the discussion. Adding the fixed-point compute handling for requantization.
1 parent 2a7aebe commit cff0bdd

File tree

13 files changed

+1309
-4
lines changed

13 files changed

+1309
-4
lines changed

include/tvm/relay/attrs/qnn.h

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relay/attrs/qnn.h
22+
* \brief Auxiliary attributes for quantized nn operators.
23+
*/
24+
#ifndef TVM_RELAY_ATTRS_NN_QUANTIZE_H_
25+
#define TVM_RELAY_ATTRS_NN_QUANTIZE_H_
26+
27+
#include <tvm/attrs.h>
28+
#include <string>
29+
30+
namespace tvm {
31+
namespace relay {
32+
33+
/*! \brief Attribute for quantized conv2d operator */
struct QConv2DAttrs : public tvm::AttrsNode<QConv2DAttrs> {
  // Traditional conv2d attributes.
  // NOTE(review): these mirror Conv2DAttrs in tvm/relay/attrs/nn.h; deriving
  // from Conv2DAttrs was attempted but the attrs macro machinery made it hard
  // (see PR discussion) — keep the two in sync manually until unified.
  Array<IndexExpr> strides;      // sliding-window stride per spatial dimension
  Array<IndexExpr> padding;      // implicit zero padding applied on both sides
  Array<IndexExpr> dilation;     // dilation rate for dilated convolution
  int groups;                    // number of grouped-convolution groups
  IndexExpr channels;            // output channels; may be unset and inferred
  Array<IndexExpr> kernel_size;  // spatial extent of the convolution window
  std::string data_layout;       // e.g. "NCHW"
  std::string kernel_layout;     // e.g. "OIHW"
  std::string out_layout;        // empty means "same as input layout"
  DataType out_dtype;            // explicit output dtype for mixed precision

  // Quantization related attributes.
  // NOTE(review): zero points are scalar, i.e. per-tensor quantization only;
  // per-channel quantization would need Array-typed fields — confirm scope.
  int32_t input_zero_point;   // zero point of the input tensor
  int32_t kernel_zero_point;  // zero point of the kernel tensor

  TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") {
    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
        .describe("Specifies the strides of the convolution.");
    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
        .describe("If padding is non-zero, then the input is implicitly zero-padded"
                  "on both sides for padding number of points");
    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
        .describe("Specifies the dilation rate to use for dilated convolution.");
    TVM_ATTR_FIELD(groups).set_default(1)
        .describe("Controls the connections between inputs and outputs."
                  "At groups=1, all inputs are convolved to all outputs."
                  "At groups=2, the operation becomes equivalent to having two convolution"
                  "layers side by side, each seeing half the input channels, and producing"
                  "half the output channels, and both subsequently concatenated.");
    TVM_ATTR_FIELD(channels)
        .describe("The number of output channels in the convolution."
                  " If it is not set, inferred by shape of the weight.")
        .set_default(NullValue<IndexExpr>());
    TVM_ATTR_FIELD(kernel_size)
        .describe("Specifies the dimensions of the convolution window.")
        .set_default(NullValue<Array<IndexExpr> >());
    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
                  "dimensions respectively. Convolution is applied on the 'H' and"
                  "'W' dimensions.");
    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
        .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
                  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
                  "dimensions respectively.");
    TVM_ATTR_FIELD(out_layout).set_default("")
        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
                  "dimensions respectively. Default to be same as input layout.");
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type, set to explicit type under mixed precision setting");
    // The zero-point fields have no default: callers must always supply them.
    TVM_ATTR_FIELD(input_zero_point)
        .describe("The zero point of the input tensor.");
    TVM_ATTR_FIELD(kernel_zero_point)
        .describe("The zero point of the kernel tensor.");
  }
};
94+
95+
96+
/*! \brief Attribute for requantize operator */
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
  double input_scale;         // scale of the input quantized tensor
  int32_t input_zero_point;   // zero point of the input quantized tensor
  double output_scale;        // scale of the output quantized tensor
  int32_t output_zero_point;  // zero point of the output quantized tensor
  // When true, the scale handling is done with fixed-point integer arithmetic
  // instead of float; see the commit notes on fixed-point requantization.
  bool use_int_compute;
  DataType out_dtype;         // explicit output dtype for mixed precision

  TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
    // NOTE(review): registration order (zero points before scales) differs from
    // the declaration order above; attrs are listed in registration order, so
    // confirm this ordering is intentional before relying on positional access.
    TVM_ATTR_FIELD(input_zero_point)
        .describe("The zero point of the input tensor.");
    TVM_ATTR_FIELD(output_zero_point)
        .describe("The zero point of the output tensor.");
    TVM_ATTR_FIELD(input_scale)
        .describe("The scale of the input tensor.");
    TVM_ATTR_FIELD(output_scale)
        .describe("The scale of the output tensor.");
    TVM_ATTR_FIELD(use_int_compute).set_default(false)
        .describe("When true, the integer computation is used to handle output scale");
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type, set to explicit type under mixed precision setting");
  }
};
121+
122+
123+
} // namespace relay
124+
} // namespace tvm
125+
#endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_

include/tvm/relay/quantize_util.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relay/quantize_util.h
22+
* \brief Utility methods needed by quantized ops that can be shared
23+
*/
24+
25+
#ifndef TVM_QUANTIZE_UTIL_H
26+
#define TVM_QUANTIZE_UTIL_H
27+
28+
#include <tvm/expr.h>

#include <cstdint>
#include <limits>

#include "./base.h"
30+
31+
namespace tvm {
32+
namespace relay {
33+
34+
// ---------------------------------------------------------------------------
// Dtype classification helpers shared by the quantize/dequantize/requantize
// type-checking code below.
// ---------------------------------------------------------------------------

/*! \brief True iff dtype is signed 8-bit integer. */
inline bool is_Int8(const DataType& dtype) {
  return dtype == Int(8);
}

/*! \brief True iff dtype is unsigned 8-bit integer. */
inline bool is_UInt8(const DataType& dtype) {
  return dtype == UInt(8);
}

/*! \brief True iff dtype is signed 16-bit integer. */
inline bool is_Int16(const DataType& dtype) {
  return dtype == Int(16);
}

/*! \brief True iff dtype is unsigned 16-bit integer. */
inline bool is_UInt16(const DataType& dtype) {
  return dtype == UInt(16);
}

/*! \brief True iff dtype is signed 32-bit integer. */
inline bool is_Int32(const DataType& dtype) {
  return dtype == Int(32);
}

/*! \brief True iff dtype is unsigned 32-bit integer. */
inline bool is_UInt32(const DataType& dtype) {
  return dtype == UInt(32);
}

/*! \brief True iff dtype is 32-bit float. */
inline bool is_Float32(const DataType& dtype) {
  return dtype == Float(32);
}

/*!
 * \brief True iff dtype is a supported quantized storage type.
 * Note: only 8/16-bit integers qualify; (u)int32 is treated as an
 * accumulator type, not a quantized storage type.
 */
inline bool is_quantized_type(const DataType& dtype) {
  return is_Int8(dtype) || is_UInt8(dtype)
      || is_Int16(dtype) || is_UInt16(dtype);
}
69+
70+
/*!
 * \brief Kind of quantization op being validated; selects the legal
 *        input/output dtypes in the checks below.
 */
enum class QuantizeOpType : uint8_t {
  // NOTE(review): name suggests this covers quantize possibly fused with a
  // requantize step — confirm against the op lowering once it lands.
  Quantize_Requantize,
  // Dequantize: quantized storage type back to float32.
  Dequantize,
  // Requantize: one quantized representation to another (e.g. int32 accum).
  Requantize
};
75+
76+
inline bool is_valid_quantized_op_input_type(const QuantizeOpType &op_type, const DataType &in_dtype) {
77+
switch(op_type) {
78+
case QuantizeOpType::Quantize_Requantize:
79+
return is_Float32(in_dtype) || is_quantized_type(in_dtype);
80+
case QuantizeOpType ::Dequantize:
81+
return is_quantized_type(in_dtype);
82+
case QuantizeOpType ::Requantize:
83+
return is_Int16(in_dtype) || is_Int32(in_dtype);
84+
default:
85+
return false;
86+
}
87+
}
88+
89+
inline bool is_valid_quantized_op_output_type(const QuantizeOpType &op_type, const DataType &in_dtype) {
90+
switch(op_type) {
91+
case QuantizeOpType::Quantize_Requantize:
92+
return is_quantized_type(in_dtype);
93+
case QuantizeOpType::Dequantize:
94+
return is_Float32(in_dtype);
95+
default:
96+
return false;
97+
}
98+
}
99+
100+
inline const int32_t get_qmin(const DataType& dtype) {
101+
if (is_Int8(dtype)) {
102+
return std::numeric_limits<int8_t>::min();
103+
} else if (is_UInt8(dtype)) {
104+
return std::numeric_limits<uint8_t>::min();
105+
} else if (is_Int16(dtype)) {
106+
return std::numeric_limits<int16_t>::min();
107+
} else if (is_UInt16(dtype)) {
108+
return std::numeric_limits<uint16_t>::min();
109+
} else if (is_Int32(dtype)) {
110+
return std::numeric_limits<int32_t>::min();
111+
} else if (is_UInt32(dtype)) {
112+
return std::numeric_limits<uint32_t>::min();
113+
}
114+
LOG(FATAL) << "Type not supported\n";
115+
return -1;
116+
}
117+
118+
119+
inline const int32_t get_qmax(const DataType& dtype) {
120+
if (is_Int8(dtype)) {
121+
return std::numeric_limits<int8_t>::max();
122+
} else if (is_UInt8(dtype)) {
123+
return std::numeric_limits<uint8_t>::max();
124+
} else if (is_Int16(dtype)) {
125+
return std::numeric_limits<int16_t>::max();
126+
} else if (is_UInt16(dtype)) {
127+
return std::numeric_limits<uint16_t>::max();
128+
} else if (is_Int32(dtype)) {
129+
return std::numeric_limits<int32_t>::max();
130+
} else if (is_UInt32(dtype)) {
131+
return std::numeric_limits<uint32_t>::max();
132+
}
133+
LOG(FATAL) << "Type not supported\n";
134+
return -1;
135+
}
136+
137+
} // namespace relay
138+
} // namespace tvm
139+
#endif //TVM_QUANTIZE_UTIL_H

python/tvm/relay/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .transform import *
2727
from .algorithm import *
2828
from . import nn
29+
from . import qnn
2930
from . import annotation
3031
from . import image
3132
from . import vision
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=wildcard-import
18+
"""Neural network related operators."""
19+
from __future__ import absolute_import as _abs
20+
from .qnn import *
21+
# from . import _nn

python/tvm/relay/op/qnn/_make.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Constructor APIs"""
18+
from ...._ffi.function import _init_api
19+
20+
_init_api("relay.op.qnn._make", __name__)

0 commit comments

Comments
 (0)