Skip to content

Commit 1f5d289

Browse files
shoubhikwweic
authored and committed
QNN quantize and dequantize operators. (apache#3745)
* QNN quantize and dequantize operators. * addressing review comments. * addressing review comments. * Adding new line at the end of the file. * Adhering to styling guidelines. * Adding name to contributors. * Fixing lint issue. * Fixing file name. * Removing unnecessary code.
1 parent f354f07 commit 1f5d289

File tree

9 files changed

+464
-11
lines changed

9 files changed

+464
-11
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,4 @@ We do encourage everyone to work anything they are interested in.
111111
- [Haolong Zhang](https://github.com/haolongzhangm)
112112
- [Cody Hao Yu](https://github.com/comaniac)
113113
- [Chris Nuernberger](https://github.com/cnuernber)
114+
- [Shoubhik Bhattacharya](https://github.com/shoubhik)

include/tvm/relay/qnn/attrs.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,38 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
6565
}
6666
};
6767

68+
/*! \brief Attribute for quantize operator */
69+
struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
70+
int32_t output_zero_point;
71+
double output_scale;
72+
DataType out_dtype;
73+
74+
TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
75+
TVM_ATTR_FIELD(out_dtype)
76+
.describe("Output data type, can be one of [int8 or uint8].");
77+
78+
TVM_ATTR_FIELD(output_zero_point)
79+
.describe("The zero_point for the activation of this op.");
80+
81+
TVM_ATTR_FIELD(output_scale)
82+
.describe("The scale for the activation of this op.");
83+
}
84+
};
85+
86+
/*! \brief Attribute for dequantize operator */
87+
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
88+
int32_t input_zero_point;
89+
double input_scale;
90+
91+
TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
92+
TVM_ATTR_FIELD(input_zero_point)
93+
.describe("The zero_point for the input tensor of this op.");
94+
95+
TVM_ATTR_FIELD(input_scale)
96+
.describe("The scale for the input tensor of this op.");
97+
}
98+
};
99+
68100
} // namespace qnn
69101
} // namespace relay
70102
} // namespace tvm

python/tvm/relay/qnn/op/qnn.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,66 @@ def requantize(data,
7474
rounding,
7575
out_dtype)
7676

77+
78+
def quantize(data,
             output_scale,
             output_zero_point,
             out_dtype='int8'):
    r""" Quantize op
    This operator takes float32 as input and produces quantized int8 or uint8 as output.
    The input tensor can be of any shape. The output shape is the same as input shape.

    Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
                     out_dtype::min,
                     out_dtype::max)

    Parameters
    ----------
    data : tvm.relay.Expr
        The input tensor to be quantized. Can be of type float32.
    output_scale : float
        The output scale.
    output_zero_point : int
        The output zero_point.
    out_dtype : str, optional
        The data type of the output tensor. Can be [int8, uint8].

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """

    return _make.quantize(data,
                          output_scale,
                          output_zero_point,
                          out_dtype)
110+
111+
112+
def dequantize(data,
               input_scale,
               input_zero_point):
    r""" Dequantize op
    This operator takes quantized int8 and uint8 as input and produces
    dequantized float32 as output. The output shape is the same as input shape. The input
    tensor can be of any shape.

    Parameters
    ----------
    data : tvm.relay.Expr
        The input tensor to be dequantized. Can be of type [int8, uint8].
    input_scale : float
        The input scale.
    input_zero_point : int
        The input zero_point.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """

    return _make.dequantize(data,
                            input_scale,
                            input_zero_point)
77137
def concatenate(data,
78138
input_scales,
79139
input_zero_points,

src/relay/qnn/op/dequantize.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file src/relay/qnn/op/dequantize.cc
23+
* \brief QNN dequantize operator. Dequantize operator converts from quantized
24+
* domain to unquantized domain.
25+
*/
26+
27+
#include <tvm/relay/analysis.h>
28+
#include <tvm/relay/op_attr_types.h>
29+
#include <tvm/relay/qnn/attrs.h>
30+
#include "../../pass/pattern_util.h"
31+
#include "../util.h"
32+
33+
namespace tvm {
34+
namespace relay {
35+
namespace qnn {
36+
37+
TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
38+
39+
bool DequantizeRel(const Array<Type>& types,
40+
int num_inputs,
41+
const Attrs& attrs,
42+
const TypeReporter& reporter) {
43+
CHECK_EQ(types.size(), 2);
44+
const auto* data = types[0].as<TensorTypeNode>();
45+
const auto input_dtype = data->dtype;
46+
CHECK(input_dtype == Int(8) || input_dtype == UInt(8))
47+
<< "Input type should be one of the quantized types [unit8, int8] but was " << input_dtype;
48+
const Array<tvm::Expr> oshape = data->shape;
49+
// assign output type, output will always be float 32.
50+
reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));
51+
return true;
52+
}
53+
54+
// Build a qnn.dequantize call node carrying the given quantization params.
// real_value = scale * (quantized_value - zero_point)
// A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
Expr MakeDequantize(Expr data,
                    double input_scale,
                    int32_t input_zero_point) {
  auto dequantize_attrs = make_node<DequantizeAttrs>();
  dequantize_attrs->input_zero_point = input_zero_point;
  dequantize_attrs->input_scale = input_scale;
  static const Op& dequantize_op = Op::Get("qnn.dequantize");
  return CallNode::make(dequantize_op, {data}, Attrs(dequantize_attrs), {});
}
65+
66+
// Lower qnn.dequantize into plain Relay ops:
//   output = input_scale * (cast<int32>(input) - input_zero_point)
Expr DequantizeLower(const Expr& input_tensor,
                     const DequantizeAttrs* attrs) {
  const auto zero_point = MakeConstantScalar(Int(32), attrs->input_zero_point);
  const auto scale = MakeConstantScalar(Float(32), attrs->input_scale);
  const auto shifted = Subtract(Cast(input_tensor, Int(32)), zero_point);
  return Multiply(Cast(shifted, Float(32)), scale);
}
74+
75+
// FTVMLegalize hook: expands a qnn.dequantize call into the lowered Relay ops.
Expr DequantizeLegalize(const Attrs& attrs,
                        const Array<Expr>& new_args,
                        const Array<tvm::relay::Type>& arg_types) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(arg_types.size(), 1);
  const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
  CHECK(dequantize_attrs != nullptr);
  return DequantizeLower(new_args[0], dequantize_attrs);
}
85+
86+
// Operator registration for qnn.dequantize.
// BUGFIX: the describe text ended with "The input data can be of floating
// point", which contradicts DequantizeRel — the input must be quantized
// (int8/uint8); only the OUTPUT is float32.
RELAY_REGISTER_OP("qnn.dequantize")
.describe(R"code(Dequantizes the input and produces float32 output.
The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
- **data**: Quantized tensor of any shape to dequantize. The input data must be of type int8 or uint8.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DequantizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to dequantize.")
.set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel)
.set_attr<FTVMLegalize>("FTVMLegalize", DequantizeLegalize);

TVM_REGISTER_API("relay.qnn.op._make.dequantize")
.set_body_typed(MakeDequantize);
100+
101+
} // namespace qnn
102+
} // namespace relay
103+
} // namespace tvm

src/relay/qnn/op/quantize.cc

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file src/relay/qnn/op/quantize.cc
23+
* \brief QNN quantize operator. Quantize operator converts from unquantized
24+
* domain to quantized domain.
25+
*/
26+
27+
#include <tvm/relay/analysis.h>
28+
#include <tvm/relay/op_attr_types.h>
29+
#include <tvm/relay/qnn/attrs.h>
30+
#include "../../pass/pattern_util.h"
31+
#include "../util.h"
32+
33+
namespace tvm {
34+
namespace relay {
35+
namespace qnn {
36+
37+
TVM_REGISTER_NODE_TYPE(QuantizeAttrs);
38+
39+
bool QuantizeRel(const Array<Type>& types,
40+
int num_inputs,
41+
const Attrs& attrs,
42+
const TypeReporter& reporter) {
43+
CHECK_EQ(types.size(), 2);
44+
const auto* data = types[0].as<TensorTypeNode>();
45+
const auto input_dtype = data->dtype;
46+
CHECK(input_dtype == Float(32))
47+
<< "Input type should be one of float32 but was " << input_dtype;
48+
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
49+
const Array<tvm::Expr> oshape = data->shape;
50+
const DataType out_dtype = quantize_attrs->out_dtype;
51+
CHECK(out_dtype == Int(8) || out_dtype == UInt(8))
52+
<< "Output type should be one of [int8, unit8 ] but was " << out_dtype;
53+
// assign output type
54+
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
55+
return true;
56+
}
57+
58+
// Build a qnn.quantize call node carrying the given quantization params.
// result_quantized_value = result_zero_point + result_real_value / result_scale
// A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
Expr MakeQuantize(Expr data,
                  double output_scale,
                  int32_t output_zero_point,
                  DataType out_dtype) {
  auto quantize_attrs = make_node<QuantizeAttrs>();
  quantize_attrs->out_dtype = std::move(out_dtype);
  quantize_attrs->output_zero_point = output_zero_point;
  quantize_attrs->output_scale = output_scale;
  static const Op& quantize_op = Op::Get("qnn.quantize");
  return CallNode::make(quantize_op, {data}, Attrs(quantize_attrs), {});
}
71+
72+
// Lower qnn.quantize into plain Relay ops:
//   out = cast(clip(round(data / scale) + zero_point, qmin, qmax), out_dtype)
Expr QuantizeLower(const Expr& input_tensor,
                   const QuantizeAttrs* attrs) {
  const auto out_dtype = attrs->out_dtype;
  const auto zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
  const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
  const int32_t q_min = GetQmin(out_dtype);
  const int32_t q_max = GetQmax(out_dtype);
  const auto scaled = Cast(Round(Divide(input_tensor, scale)), Int(32));
  const auto shifted = Add(scaled, zero_point);
  const auto clamped = Clip(shifted, q_min, q_max);
  return Cast(clamped, out_dtype);
}
85+
86+
// FTVMLegalize hook: expands a qnn.quantize call into the lowered Relay ops.
Expr QuantizeLegalize(const Attrs& attrs,
                      const Array<Expr>& new_args,
                      const Array<tvm::relay::Type>& arg_types) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(arg_types.size(), 1);
  const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
  CHECK(quantize_attrs != nullptr);
  return QuantizeLower(new_args[0], quantize_attrs);
}
97+
98+
// Operator registration for qnn.quantize.
// BUGFIX: the describe text claimed the input "can be either float or
// quantized(int8, unit8)" and that quantized input would be requantized —
// but QuantizeRel only accepts float32 input (requantization is the separate
// qnn.requantize op). Also fixes the "unit8" typo.
RELAY_REGISTER_OP("qnn.quantize")
.describe(R"code(Quantizes the input and produces quantized output.
The input is a float32 tensor. Given the output scale and zero point,
this op quantizes the float input to the output, in int8 or uint8 format:
Q_output = clamp(round(input / output_scale) + output_zero_point,
                 out_dtype::min, out_dtype::max)
- **data**: Tensor of any shape to quantize. The input data is of floating point type.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.QuantizeAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to quantize.")
.set_support_level(11)
.add_type_rel("Quantize", QuantizeRel)
.set_attr<FTVMLegalize>("FTVMLegalize", QuantizeLegalize);

TVM_REGISTER_API("relay.qnn.op._make.quantize")
.set_body_typed(MakeQuantize);
118+
119+
} // namespace qnn
120+
} // namespace relay
121+
} // namespace tvm

src/relay/qnn/op/requantize.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
/*!
2121
* Copyright (c) 2019 by Contributors
22-
* \file requantize.cc
22+
* \file src/relay/qnn/op/requantize.cc
2323
* \brief QNN requantize operator.
2424
*/
2525

@@ -228,14 +228,14 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
228228
const auto* data = types[0].as<TensorTypeNode>();
229229
const auto in_dtype = data->dtype;
230230
CHECK(in_dtype == Int(8) || in_dtype == UInt(8) || in_dtype == Int(32))
231-
<< "Input type should be an integer but was " << in_dtype;
231+
<< "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
232232

233233
const Array<tvm::Expr> oshape = data->shape;
234234
// assign output type
235235
const RequantizeAttrs* param = attrs.as<RequantizeAttrs>();
236236
auto out_dtype = param->out_dtype;
237237
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
238-
<< "Output type should be an integer but was " << out_dtype;
238+
<< "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
239239
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
240240
return true;
241241
}

0 commit comments

Comments
 (0)