
Commit 11d62df

anijain2305 authored and wweic committed
[QNN] Concat - Refactoring to C++ (apache#3819)
1 parent f1639a8 commit 11d62df

File tree

5 files changed: +197 −45 lines changed


include/tvm/relay/qnn/attrs.h

Lines changed: 29 additions & 1 deletion
@@ -88,7 +88,7 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   int32_t input_zero_point;
   double input_scale;
 
-  TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") {
+  TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
     TVM_ATTR_FIELD(input_zero_point)
         .describe("The zero_point for the input tensor of this op.");
 
@@ -97,6 +97,34 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   }
 };
 
+/*! \brief Attributes used in QNN concatenate operator */
+struct QnnConcatenateAttrs : public tvm::AttrsNode<QnnConcatenateAttrs> {
+  Array<tvm::Expr> input_scales;
+  Array<tvm::Expr> input_zero_points;
+  double output_scale;
+  int32_t output_zero_point;
+  int axis;
+
+  TVM_DECLARE_ATTRS(QnnConcatenateAttrs, "relay.attrs.QnnConcatenateAttrs") {
+    TVM_ATTR_FIELD(input_scales)
+        .describe("The list of scales of input quantized tensors.");
+
+    TVM_ATTR_FIELD(input_zero_points)
+        .describe("The list of zero points of input quantized tensors.");
+
+    TVM_ATTR_FIELD(output_zero_point)
+        .describe("The zero_point for the output tensor.");
+
+    TVM_ATTR_FIELD(output_scale)
+        .describe("The scale for the output tensor.");
+
+    TVM_ATTR_FIELD(axis)
+        .describe("The axis at which the input arrays are concatenated."
+                  " Should lie in range `[-ndim, ndim)`.")
+        .set_default(0);
+  }
+};  // struct QnnConcatenateAttrs
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm

python/tvm/relay/qnn/op/qnn.py

Lines changed: 15 additions & 40 deletions
@@ -18,7 +18,8 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
-from tvm import relay
+from tvm.expr import FloatImm, IntImm
+from tvm.relay.expr import Tuple
 from . import _make
 
 def requantize(data,
@@ -134,6 +135,8 @@ def dequantize(data,
     return _make.dequantize(data,
                             input_scale,
                             input_zero_point)
+
+
 def concatenate(data,
                 input_scales,
                 input_zero_points,
@@ -169,42 +172,14 @@ def concatenate(data,
     """
 
     data = list(data)
-    requantized_exprs = list(data)
-
-    # Find the dtype of the input expr. This is required for the requantize op. Since this is
-    # a concatenate op, the dtype of the input is the same as the dtype of the output.
-    mod = relay.Module.from_expr(data[0])
-    mod = relay.transform.InferType()(mod)
-    entry = mod["main"]
-    data0 = entry if isinstance(data[0], relay.Function) else entry.body
-    in_dtype = data0.checked_type.dtype
-
-    # First check if all the input qnn params match. If yes, we can call concatenate first,
-    # followed by a requantize.
-    if all(scale == input_scales[0] for scale in input_scales)\
-            and all(zero_point == input_zero_points[0] for zero_point in input_zero_points):
-        out = relay.concatenate(tuple(data), axis)
-        input_scale = input_scales[0]
-        input_zero_point = input_zero_points[0]
-        if input_scale != output_scale or input_zero_point != output_zero_point:
-            out = requantize(data=out,
-                             input_scale=input_scales[0],
-                             input_zero_point=input_zero_points[0],
-                             output_scale=output_scale,
-                             output_zero_point=output_zero_point,
-                             out_dtype=in_dtype)
-        return out
-
-    # If the output qnn params do not match the input qnn params, we can call requantize on the
-    # input expr first, followed by a concatenate on the requantized input exprs.
-    for idx, quantized_expr in enumerate(data):
-        input_scale = input_scales[idx]
-        input_zero_point = input_zero_points[idx]
-        if input_scale != output_scale or input_zero_point != output_zero_point:
-            requantized_exprs[idx] = requantize(data=quantized_expr,
-                                                input_scale=input_scale,
-                                                input_zero_point=input_zero_point,
-                                                output_scale=output_scale,
-                                                output_zero_point=output_zero_point,
-                                                out_dtype=in_dtype)
-    return relay.concatenate(tuple(requantized_exprs), axis)
+    if not data:
+        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if not isinstance(axis, int):
+        raise ValueError("For now, we only support integer axis")
+
+    return _make.concatenate(Tuple(data),
+                             [FloatImm("float64", x) for x in input_scales],
+                             [IntImm("int32", x) for x in input_zero_points],
+                             output_scale,
+                             output_zero_point,
+                             axis)
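
For reference, a minimal usage sketch of the refactored Python entry point (not part of the commit; the shapes and qnn params below are made-up illustrative values):

    import tvm
    from tvm import relay

    # Two quantized inputs with different qnn params (hypothetical values).
    x = relay.var("x", shape=(1, 64), dtype="uint8")
    y = relay.var("y", shape=(1, 64), dtype="uint8")
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=(0.5, 0.25),
                                 input_zero_points=(1, 2),
                                 output_scale=0.5,
                                 output_zero_point=1,
                                 axis=1)
    # The requantize/concatenate lowering now happens later, in
    # relay.qnn.transform.CanonicalizeOps(), rather than at construction time.
    func = relay.Function([x, y], z)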

src/relay/qnn/op/concatenate.cc

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/op/concatenate.cc
+ * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis.
+ */
+
+#include <tvm/ir.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/qnn/attrs.h>
+#include "../../op/tensor/transform.h"
+#include "../../pass/pattern_util.h"
+#include "../util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+TVM_REGISTER_NODE_TYPE(QnnConcatenateAttrs);
+
+Expr MakeQnnConcatenate(Expr data, Array<tvm::Expr> input_scales,
+                        Array<tvm::Expr> input_zero_points, double output_scale,
+                        int32_t output_zero_point, int axis) {
+  auto attrs = make_node<QnnConcatenateAttrs>();
+  attrs->input_scales = std::move(input_scales);
+  attrs->input_zero_points = std::move(input_zero_points);
+  attrs->output_scale = output_scale;
+  attrs->output_zero_point = output_zero_point;
+  attrs->axis = axis;
+  static const Op& op = Op::Get("qnn.concatenate");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+/*
+ * \brief Canonicalizes the QNN concatenate op.
+ * \param attrs The QNN concatenate attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for the concatenate op.
+ */
+Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
+                                const Array<tvm::relay::Type>& arg_types) {
+  // Get the attrs.
+  CHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  const auto* concatenate_attrs = attrs.as<QnnConcatenateAttrs>();
+  CHECK(concatenate_attrs != nullptr);
+  auto input_scales = concatenate_attrs->input_scales;
+  auto input_zero_points = concatenate_attrs->input_zero_points;
+  auto output_scale = concatenate_attrs->output_scale;
+  auto output_zero_point = concatenate_attrs->output_zero_point;
+
+  // Get the input dtype and shape.
+  CHECK_GE(arg_types.size(), 1);
+  auto tuple_type = arg_types[0].as<TupleTypeNode>();
+  CHECK(tuple_type != nullptr);
+
+  // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize
+  // at the start, we can insert requantize at the end if and only if all the input tensors have
+  // the same qnn params. This can be done in the future.
+
+  // If the output qnn params do not match the input qnn params, we can call requantize on the
+  // input expr first, followed by a concatenate on the requantized input exprs.
+
+  auto tuple_data = data.as<TupleNode>();
+  CHECK(tuple_data != nullptr);
+
+  int idx = 0;
+  Array<Expr> requantized_exprs;
+  for (auto quantized_expr : tuple_data->fields) {
+    // Get the input scale for the idx-th quantized input tensor.
+    auto input_scale_expr = input_scales[idx].as<tvm::ir::FloatImm>();
+    CHECK(input_scale_expr != nullptr);
+    auto input_scale = input_scale_expr->value;
+
+    // Get the zero point for the idx-th quantized input tensor.
+    auto input_zero_point_expr = input_zero_points[idx].as<tvm::ir::IntImm>();
+    CHECK(input_zero_point_expr != nullptr);
+    auto input_zero_point = input_zero_point_expr->value;
+
+    // Check if the output and input qnn params are the same. If not, requantize.
+    if (input_scale != output_scale || input_zero_point != output_zero_point) {
+      // Get the input shape and dtype.
+      auto tensor_type = tuple_type->fields[idx].as<TensorTypeNode>();
+      auto input_dtype = tensor_type->dtype;
+      auto input_shape = tensor_type->shape;
+
+      // Requantize the input.
+      auto requantized_expr = Requantize(quantized_expr, input_shape, input_scale, input_zero_point,
+                                         output_scale, output_zero_point, input_dtype);
+      requantized_exprs.push_back(requantized_expr);
+    } else {
+      requantized_exprs.push_back(quantized_expr);
+    }
+    idx++;
+  }
+  return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
+}
+
+RELAY_REGISTER_OP("qnn.concatenate")
+.describe(R"code(Concatenate the quantized input tensors along the given axis.
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.QnnConcatenateAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The tensor to concatenate.")
+.set_support_level(11)
+.add_type_rel("QnnConcatenate", ConcatenateRel<QnnConcatenateAttrs>)
+.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize);
+
+TVM_REGISTER_API("relay.qnn.op._make.concatenate")
+.set_body_typed(MakeQnnConcatenate);
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
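
The canonicalization above reduces to a per-input decision. A pure-Python sketch of the same control flow (illustrative only; tagged tuples stand in for Relay exprs so the snippet is runnable):

    def lower_qnn_concatenate(exprs, input_scales, input_zero_points,
                              output_scale, output_zero_point, axis):
        requantized = []
        for expr, scale, zero_point in zip(exprs, input_scales, input_zero_points):
            # Requantize only the inputs whose qnn params differ from the output's.
            if scale != output_scale or zero_point != output_zero_point:
                expr = ("requantize", expr, scale, zero_point,
                        output_scale, output_zero_point)
            requantized.append(expr)
        return ("concatenate", requantized, axis)

    # lower_qnn_concatenate(["x", "y"], [0.5, 0.25], [1, 2], 0.5, 1, 0)
    # requantizes only "y" -- the one-requantize case exercised by the tests below.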

src/relay/qnn/util.h

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,8 @@
 #include <tvm/expr.h>
 #include <tvm/relay/expr.h>
 #include <limits>
+#include <string>
+#include <utility>
 
 namespace tvm {
 namespace relay {
@@ -67,6 +69,23 @@ static inline const int32_t GetQmax(const DataType& dtype) {
   }
 }
 
+Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype);
+
+static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
+                              double input_scale, int32_t input_zero_point, double output_scale,
+                              int32_t output_zero_point, const DataType& out_dtype,
+                              const std::string& rounding = "TONEAREST") {
+  auto attrs = make_node<RequantizeAttrs>();
+  attrs->input_scale = std::move(input_scale);
+  attrs->input_zero_point = std::move(input_zero_point);
+  attrs->output_scale = std::move(output_scale);
+  attrs->output_zero_point = std::move(output_zero_point);
+  attrs->rounding = std::move(rounding);
+  attrs->out_dtype = std::move(out_dtype);
+  return RequantizeLower(data, attrs.operator->(), input_shape, out_dtype);
+}
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
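
Numerically, Requantize re-expresses a quantized value in the output qnn params. A minimal NumPy reference (an assumption-laden sketch, not the TVM implementation: it rounds to nearest, matching the "TONEAREST" default, and saturates to the out_dtype range):

    import numpy as np

    # real value: s_in * (q_in - z_in) == s_out * (q_out - z_out)
    #   =>  q_out = z_out + round((s_in / s_out) * (q_in - z_in))
    def requantize_ref(q_in, s_in, z_in, s_out, z_out, out_dtype=np.uint8):
        q_out = z_out + np.rint((s_in / s_out) * (q_in.astype(np.float64) - z_in))
        info = np.iinfo(out_dtype)
        return np.clip(q_out, info.min, info.max).astype(out_dtype)

    # requantize_ref(np.array([10]), 0.25, 2, 0.5, 1) -> array([5], dtype=uint8)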

tests/python/relay/test_qnn_concatenate.py

Lines changed: 0 additions & 4 deletions
@@ -39,7 +39,6 @@ def test_same_io_qnn_params():
                                axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 0
     mod = relay.Module.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
@@ -68,7 +67,6 @@ def test_different_io_qnn_params():
                                axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 2
     mod = relay.Module.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
@@ -97,7 +95,6 @@ def test_few_same_io_qnn_params():
                                axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
@@ -126,7 +123,6 @@ def test_same_i_qnn_params():
                                axis=axis)
 
     func = relay.Function([x, y], z)
-    assert func.astext().count('requantize') == 1
     mod = relay.Module.from_expr(func)
     mod = relay.qnn.transform.CanonicalizeOps()(mod)
     func = mod["main"]
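
With the lowering moved into C++, requantize ops appear only after CanonicalizeOps, so the removed pre-canonicalization asserts no longer hold. A hedged sketch of an equivalent post-canonicalization check (the expected count depends on each test's qnn params):

    mod = relay.Module.from_expr(func)
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    # e.g. two requantizes when both inputs' params differ from the output params
    assert mod["main"].astext().count('requantize') == 2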
