Skip to content

Commit f1639a8

Browse files
anijain2305 and wweic
authored and committed
[Relay][QNN] Moving Conv, Dense, Concatenate InferTypes to header for sharing. (apache#3783)
1 parent acaed8f commit f1639a8

File tree

6 files changed

+331
-228
lines changed

6 files changed

+331
-228
lines changed

src/relay/op/nn/convolution.cc

Lines changed: 3 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -29,118 +29,14 @@
2929
#include <vector>
3030

3131
#include "../../pass/alter_op_layout.h"
32+
#include "convolution.h"
3233

3334
namespace tvm {
3435
namespace relay {
3536

3637
// relay.nn.conv2d
3738
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
3839

39-
bool Conv2DRel(const Array<Type>& types,
40-
int num_inputs,
41-
const Attrs& attrs,
42-
const TypeReporter& reporter) {
43-
CHECK_EQ(types.size(), 3);
44-
const auto* data = types[0].as<TensorTypeNode>();
45-
const auto* weight = types[1].as<TensorTypeNode>();
46-
if (data == nullptr) return false;
47-
static const Layout kNCHW("NCHW");
48-
static const Layout kOIHW("OIHW");
49-
50-
const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
51-
CHECK(param != nullptr);
52-
const Layout in_layout(param->data_layout);
53-
const Layout kernel_layout(param->kernel_layout);
54-
55-
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
56-
CHECK(trans_in_layout.defined())
57-
<< "Conv only support input layouts that are convertible from NCHW."
58-
<< " But got " << in_layout;
59-
60-
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
61-
CHECK(trans_kernel_layout.defined())
62-
<< "Conv only support kernel layouts that are convertible from OIHW."
63-
<< " But got "<< kernel_layout;
64-
65-
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
66-
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
67-
CHECK(trans_out_layout.defined())
68-
<< "Conv only support output layouts that are convertible from NCHW."
69-
<< " But got " << out_layout;
70-
71-
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
72-
73-
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
74-
// infer weight if the kernel_size and channels are defined
75-
if (param->kernel_size.defined() && param->channels.defined()) {
76-
CHECK_EQ(param->kernel_size.size(), 2);
77-
CHECK_EQ(param->dilation.size(), 2);
78-
Array<IndexExpr> wshape;
79-
80-
if (tvm::ir::Equal(param->channels, param->groups)) {
81-
// infer weight's shape for depthwise convolution
82-
wshape = {
83-
{dshape_nchw[1],
84-
param->groups / dshape_nchw[1],
85-
param->kernel_size[0],
86-
param->kernel_size[1]}};
87-
} else {
88-
wshape = {
89-
{param->channels,
90-
dshape_nchw[1] / param->groups,
91-
param->kernel_size[0],
92-
param->kernel_size[1]}};
93-
}
94-
95-
wshape = trans_kernel_layout.BackwardShape(wshape);
96-
channels = param->channels;
97-
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
98-
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
99-
DataType weight_dtype = data->dtype;
100-
if (weight != nullptr) {
101-
weight_dtype = weight->dtype;
102-
}
103-
// assign result to reporter
104-
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
105-
} else {
106-
// use weight to infer the conv shape.
107-
if (weight == nullptr) return false;
108-
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
109-
if (param->kernel_size.defined()) {
110-
CHECK_EQ(param->kernel_size.size(), 2);
111-
// check the size
112-
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
113-
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
114-
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
115-
<< " kernel_size=" << param->kernel_size
116-
<< " wshape=" << wshape;
117-
}
118-
if (param->channels.defined()) {
119-
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
120-
<< "Conv2D: shape of weight is inconsistent with channels, "
121-
<< " channels=" << param->channels
122-
<< " wshape=" << wshape;
123-
}
124-
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
125-
channels = wshape[0];
126-
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
127-
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
128-
}
129-
// dilation
130-
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
131-
132-
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
133-
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
134-
DataType out_dtype = param->out_dtype;
135-
if (out_dtype.bits() == 0) {
136-
out_dtype = data->dtype;
137-
}
138-
oshape = trans_out_layout.BackwardShape(oshape);
139-
// assign output type
140-
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
141-
return true;
142-
}
143-
14440
template<typename T>
14541
Array<Array<Layout> > Conv2DInferCorrectLayout(
14642
const Attrs& attrs,
@@ -208,7 +104,7 @@ with the layer input to produce a tensor of outputs.
208104
.add_argument("data", "Tensor", "The input tensor.")
209105
.add_argument("weight", "Tensor", "The weight tensor.")
210106
.set_support_level(2)
211-
.add_type_rel("Conv2D", Conv2DRel)
107+
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
212108
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
213109

214110

@@ -770,7 +666,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
770666
.add_argument("data", "Tensor", "The input tensor.")
771667
.add_argument("weight", "Tensor", "The weight tensor.")
772668
.set_support_level(10)
773-
.add_type_rel("Conv2D", Conv2DRel)
669+
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
774670
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
775671
Conv2DInferCorrectLayout<Conv2DAttrs>);
776672

src/relay/op/nn/convolution.h

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file src/relay/op/nn/convolution.h
23+
* \brief Properties definition of the convolution operator for sharing.
24+
*/
25+
#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
26+
#define TVM_RELAY_OP_NN_CONVOLUTION_H_
27+
28+
#include <string>
29+
#include <utility>
30+
31+
namespace tvm {
32+
namespace relay {
33+
34+
template <typename AttrType>
35+
bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
36+
const TypeReporter& reporter) {
37+
CHECK_EQ(types.size(), 3);
38+
const auto* data = types[0].as<TensorTypeNode>();
39+
const auto* weight = types[1].as<TensorTypeNode>();
40+
if (data == nullptr) return false;
41+
static const Layout kNCHW("NCHW");
42+
static const Layout kOIHW("OIHW");
43+
44+
const AttrType* param = attrs.as<AttrType>();
45+
CHECK(param != nullptr);
46+
const Layout in_layout(param->data_layout);
47+
const Layout kernel_layout(param->kernel_layout);
48+
49+
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
50+
CHECK(trans_in_layout.defined())
51+
<< "Conv only support input layouts that are convertible from NCHW."
52+
<< " But got " << in_layout;
53+
54+
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
55+
CHECK(trans_kernel_layout.defined())
56+
<< "Conv only support kernel layouts that are convertible from OIHW."
57+
<< " But got " << kernel_layout;
58+
59+
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
60+
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
61+
CHECK(trans_out_layout.defined())
62+
<< "Conv only support output layouts that are convertible from NCHW."
63+
<< " But got " << out_layout;
64+
65+
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
66+
67+
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
68+
// infer weight if the kernel_size and channels are defined
69+
if (param->kernel_size.defined() && param->channels.defined()) {
70+
CHECK_EQ(param->kernel_size.size(), 2);
71+
CHECK_EQ(param->dilation.size(), 2);
72+
Array<IndexExpr> wshape;
73+
74+
if (tvm::ir::Equal(param->channels, param->groups)) {
75+
// infer weight's shape for depthwise convolution
76+
wshape = {{dshape_nchw[1], param->groups / dshape_nchw[1], param->kernel_size[0],
77+
param->kernel_size[1]}};
78+
} else {
79+
wshape = {{param->channels, dshape_nchw[1] / param->groups, param->kernel_size[0],
80+
param->kernel_size[1]}};
81+
}
82+
83+
wshape = trans_kernel_layout.BackwardShape(wshape);
84+
channels = param->channels;
85+
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
86+
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
87+
DataType weight_dtype = data->dtype;
88+
if (weight != nullptr) {
89+
weight_dtype = weight->dtype;
90+
}
91+
// assign result to reporter
92+
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
93+
} else {
94+
// use weight to infer the conv shape.
95+
if (weight == nullptr) return false;
96+
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
97+
if (param->kernel_size.defined()) {
98+
CHECK_EQ(param->kernel_size.size(), 2);
99+
// check the size
100+
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
101+
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
102+
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
103+
<< " kernel_size=" << param->kernel_size << " wshape=" << wshape;
104+
}
105+
if (param->channels.defined()) {
106+
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
107+
<< "Conv2D: shape of weight is inconsistent with channels, "
108+
<< " channels=" << param->channels << " wshape=" << wshape;
109+
}
110+
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
111+
channels = wshape[0];
112+
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
113+
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
114+
}
115+
// dilation
116+
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
117+
118+
oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
119+
oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
120+
DataType out_dtype = param->out_dtype;
121+
if (out_dtype.bits() == 0) {
122+
out_dtype = data->dtype;
123+
}
124+
oshape = trans_out_layout.BackwardShape(oshape);
125+
// assign output type
126+
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
127+
return true;
128+
}
129+
130+
} // namespace relay
131+
} // namespace tvm
132+
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_

src/relay/op/nn/nn.cc

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "../type_relations.h"
3636
#include "../../pass/alter_op_layout.h"
3737
#include "../op_common.h"
38+
#include "nn.h"
3839

3940
namespace tvm {
4041
namespace relay {
@@ -102,45 +103,6 @@ RELAY_REGISTER_OP("nn.bias_add")
102103
// relay.nn.dense
103104
TVM_REGISTER_NODE_TYPE(DenseAttrs);
104105

105-
106-
bool DenseRel(const Array<Type>& types,
107-
int num_inputs,
108-
const Attrs& attrs,
109-
const TypeReporter& reporter) {
110-
CHECK_EQ(types.size(), 3);
111-
const auto* data = types[0].as<TensorTypeNode>();
112-
const auto* weight = types[1].as<TensorTypeNode>();
113-
if (data == nullptr) return false;
114-
115-
const DenseAttrs* param = attrs.as<DenseAttrs>();
116-
CHECK(param != nullptr);
117-
118-
CHECK(static_cast<int>(data->shape.size()) != 0);
119-
120-
Array<tvm::Expr> oshape = data->shape;
121-
if (param->units.defined()) {
122-
Array<tvm::Expr> dshape = data->shape;
123-
// validate the weight shape is proper if defined
124-
// Assign weight type
125-
Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
126-
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
127-
oshape.Set((oshape.size() - 1), param->units);
128-
} else {
129-
if (weight == nullptr) return false;
130-
Array<tvm::Expr> wshape = weight->shape;
131-
oshape.Set((oshape.size() - 1), wshape[0]);
132-
}
133-
134-
DataType out_dtype = param->out_dtype;
135-
if (out_dtype.bits() == 0) {
136-
out_dtype = data->dtype;
137-
}
138-
// assign output type
139-
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
140-
return true;
141-
}
142-
143-
144106
// Positional relay function to create dense operator used by frontend FFI.
145107
Expr MakeDense(Expr data,
146108
Expr weight,
@@ -171,7 +133,7 @@ RELAY_REGISTER_OP("nn.dense")
171133
.add_argument("data", "nD Tensor", "Input data.")
172134
.add_argument("weight", "2D Tensor", "Weight matrix.")
173135
.set_support_level(1)
174-
.add_type_rel("Dense", DenseRel);
136+
.add_type_rel("Dense", DenseRel<DenseAttrs>);
175137

176138
// relay.leaky_relu
177139
TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);

0 commit comments

Comments
 (0)