
Commit 87f012e

[Relay][AlterOp] NHWC to NCHWc support for Pool, pad, concatenate, sum.
1 parent d703fb4 commit 87f012e

6 files changed: +511, -35 lines

python/tvm/relay/frontend/tflite.py

Lines changed: 6 additions & 4 deletions
@@ -748,10 +748,12 @@ def convert_conv(self, op, conv_type):
         elif padding == Padding.SAME:
             pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
-            in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
-                                                          (pad_top, pad_bottom),
-                                                          (pad_left, pad_right),
-                                                          (0, 0)))
+            do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
+            if do_pad:
+                in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
+                                                              (pad_top, pad_bottom),
+                                                              (pad_left, pad_right),
+                                                              (0, 0)))
         else:
             raise tvm.error.OpAttributeUnImplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))
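
This guard matters for the layout work in the rest of the commit: with SAME padding, the pad amounts often work out to all zeros (for example, a 1x1 kernel with stride 1), and the frontend previously emitted a no-op nn.pad anyway, one more op for the AlterOpLayout pass to see through. A minimal standalone sketch of the arithmetic, assuming get_pad_value follows the standard TFLite SAME-padding formula (the helper below is illustrative, not the exact frontend code):

def get_pad_value(in_size, kernel, stride):
    """Total SAME padding for one spatial dim, split into (before, after)."""
    out_size = (in_size + stride - 1) // stride                # ceil division
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2

# 3x3 stride-1 kernel on a 224x224 input: padding is needed, nn.pad is emitted.
pad_top, pad_bottom = get_pad_value(224, 3, 1)                 # (1, 1)
pad_left, pad_right = get_pad_value(224, 3, 1)                 # (1, 1)
do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
print(do_pad)                                                  # True

# 1x1 stride-1 kernel: all four pads are 0, so the no-op nn.pad is now skipped.
print(get_pad_value(224, 1, 1))                                # (0, 0)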

src/relay/op/nn/pad.cc

Lines changed: 77 additions & 2 deletions
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

@@ -36,6 +36,80 @@ namespace relay {
 // relay.nn.pad
 TVM_REGISTER_NODE_TYPE(PadAttrs);
 
+Array<Array<Layout> > PadInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  // NOTE: Discard "const" qualifier here.
+  PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
+
+  Layout ret;
+  // If new_in_layouts is defined, this code tries to modify the layout.
+  bool is_layout_modified = new_in_layouts.defined();
+  if (new_in_layouts.defined()) {
+    // Create a map from axis to pad_width. For the new layout, a new pad_width is generated
+    // using the map. The new layout is rejected if padding happens along an axis that was
+    // split.
+
+    // 1) Create a map from axis to pad_width using the old layout.
+    std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+    int index_counter = 0;
+    CHECK_EQ(new_in_layouts.size(), 1);
+    for (auto iter_var : old_in_layouts[0]->axes) {
+      const auto& old_layout_axis = LayoutAxis::Get(iter_var);
+      axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
+      index_counter++;
+    }
+
+    // 2) Create the new pad_width by walking over the new layout and using the map.
+    tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+    for (auto iter_var : new_in_layouts[0]->axes) {
+      const auto& new_layout_axis = LayoutAxis::Get(iter_var);
+      auto axis_name = new_layout_axis.name();
+      if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
+        // This is a primal axis, so directly use the original pad_width.
+        new_pad_width.push_back(axis_pad_width.at(axis_name));
+      } else {
+        // This axis got split, so check that its pad_width was [0, 0] originally.
+        const auto& dual_axis = new_layout_axis.ToPrimal();
+        auto dual_axis_name = dual_axis.name();
+        CHECK(axis_pad_width.count(dual_axis_name));
+        new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
+
+        // If any pad_width element is non-zero, do not change the layout.
+        for (auto width : axis_pad_width.at(dual_axis_name)) {
+          if (auto* width_imm = width.as<IntImm>()) {
+            if (width_imm->value != 0) {
+              is_layout_modified = false;
+            }
+          } else {
+            is_layout_modified = false;
+          }
+        }
+      }
+    }
+
+    // If the above conditions are satisfied, we can set the newly created pad_width and use
+    // the new layout.
+    if (is_layout_modified) {
+      ret = new_in_layouts[0];
+      params->pad_width = new_pad_width;
+    }
+  }
+
+  if (!is_layout_modified) {
+    if (old_in_layouts.defined()) {
+      CHECK_EQ(old_in_layouts.size(), 1);
+      ret = old_in_layouts[0];
+    } else {
+      ret = Layout::Undef();
+    }
+  }
+
+  return Array<Array<Layout> >{{ret}, {ret}};
+}
+
 bool PadRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,

@@ -133,6 +207,7 @@ RELAY_REGISTER_OP("nn.pad")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("Pad", PadRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
 .set_attr<TOpPattern>("TOpPattern", kInjective)
 .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
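
PadInferCorrectLayout lets nn.pad accept an altered layout from its producer instead of forcing a layout_transform back: a primal axis keeps its original (before, after) pair, while a split axis may only carry (0, 0), otherwise the new layout is rejected. A pure-Python model of that walk (remap_pad_width is a hypothetical illustration, not TVM API; layouts are spelled as lists of axis names, lowercase for split axes):

def remap_pad_width(old_layout, new_layout, pad_width):
    # 1) Map each axis of the old (primal) layout to its (before, after) pair.
    axis_pad_width = dict(zip(old_layout, pad_width))

    new_pad_width = []
    for axis in new_layout:
        if axis in axis_pad_width:            # primal axis: reuse the pair
            new_pad_width.append(axis_pad_width[axis])
        else:                                 # split axis, e.g. "c" in NCHW16c
            pair = axis_pad_width[axis.upper()]
            if pair != (0, 0):                # padding along a split axis:
                return None                   # reject the new layout
            new_pad_width.append(pair)
    return new_pad_width

# Pad on H/W only: safe to carry the pad through in NCHW16c form.
print(remap_pad_width(list("NHWC"), list("NCHW") + ["c"],
                      [(0, 0), (1, 1), (1, 1), (0, 0)]))
# -> [(0, 0), (0, 0), (1, 1), (1, 1), (0, 0)]

# Pad along C, which got split: the layout change is rejected.
print(remap_pad_width(list("NHWC"), list("NCHW") + ["c"],
                      [(0, 0), (0, 0), (0, 0), (2, 2)]))
# -> None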

src/relay/op/nn/pooling.cc

Lines changed: 2 additions & 8 deletions
@@ -47,15 +47,9 @@ Array<Array<Layout> > Pool2DInferCorrectLayout(
   T *params = const_cast<T*>(attrs.as<T>());
 
   if (new_in_layouts.defined()) {
+    // Set the pool with the new layout.
     CHECK_EQ(new_in_layouts.size(), 1);
-
-    Layout raw_layout(params->layout);
-    Layout input = new_in_layouts[0];
-    if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
-        input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
-        !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
-      params->layout = input.name();  // modify self to follow the input layout
-    }
+    params->layout = new_in_layouts[0].name();
   }
 
   Layout inferred_layout(params->layout);
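
The deleted condition only accepted a new layout whose H and W positions matched the current params->layout, so a pool written against NHWC would refuse an NCHW16c producer and force a transform. Since pooling only iterates over the primal H and W axes, it can simply adopt whatever layout the producer chose, as long as H and W themselves are not split. A short Relay sketch of a pool running directly in a blocked layout (shapes are illustrative, API as of roughly this commit's era):

from tvm import relay

# NCHW16c input: the channel dimension is blocked as C = 4 * 16.
data = relay.var("data", shape=(1, 4, 56, 56, 16))
# Pooling touches only H and W, so the blocked channel axes pass through
# untouched; the output shape is (1, 4, 28, 28, 16).
pool = relay.nn.max_pool2d(data, pool_size=(2, 2), strides=(2, 2),
                           layout="NCHW16c")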

src/relay/op/tensor/reduce.cc

Lines changed: 56 additions & 2 deletions
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY

@@ -119,6 +119,59 @@ Array<Integer> GetExcludeAxes(size_t indim,
   return r_axes;
 }
 
+// Return the modified layout for the AlterOpLayout pass.
+Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
+                                              const Array<Layout>& new_in_layouts,
+                                              const Array<Layout>& old_in_layouts,
+                                              const Array<Array<IndexExpr>>& old_in_shapes) {
+  // NOTE: Discard "const" qualifier here.
+  ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
+
+  // Get the reduce axes.
+  uint32_t indim = old_in_shapes[0].size();
+  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);
+
+  Layout ret = Layout::Undef();
+  if (new_in_layouts.defined() && r_axes.size()) {
+    // Adapt to the new layout: the axes have to change. Record the original reduce axes and
+    // convert them to the axes of the modified layout.
+    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_EQ(old_in_layouts.size(), 1);
+
+    // 1) Collect the original axes.
+    std::unordered_set<std::string> old_r_dims;
+    for (auto r_axis : r_axes) {
+      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
+    }
+
+    // 2) Collect the new axes by walking the new layout.
+    tvm::Array<tvm::Integer> new_r_axes;
+    std::string new_layout_string = "";
+    int axis_index = 0;
+    for (auto iter_var : new_in_layouts[0]->axes) {
+      const auto& layout_axis = LayoutAxis::Get(iter_var);
+      const std::string& layout_dim = layout_axis.name();
+      if (old_r_dims.count(layout_dim)) {
+        new_r_axes.push_back(tvm::Integer(axis_index));
+      }
+      // Collect only the primal axes.
+      if (layout_axis.IsPrimal()) {
+        new_layout_string += layout_dim;
+        axis_index++;
+      }
+    }
+
+    // 3) Set the new axes and layout.
+    ret = Layout(new_layout_string);
+    params->axis = new_r_axes;
+  } else if (old_in_layouts.defined()) {
+    // If the new layout is undefined, set the old layout as the inferred layout.
+    CHECK_EQ(old_in_layouts.size(), 1);
+    ret = old_in_layouts[0];
+  }
+
+  return Array<Array<Layout>>{{ret}, {ret}};
+}
+
 template<typename F>
 Array<Tensor> ReduceCompute(const Attrs& attrs,

@@ -325,6 +378,7 @@ Example::
 .set_attrs_type_key("relay.attrs.ReduceAttrs")
 .set_support_level(4)
 .add_type_rel("Reduce", ReduceRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
 .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
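
Note that new_layout_string collects only primal axes, so ReduceInferCorrectLayout always answers in the unblocked form of the new layout: the reduce axes are renumbered by name, and if the producer switched to a blocked layout, AlterOpLayout can insert the layout_transform needed to satisfy the requested primal layout. A pure-Python model of the renumbering (remap_reduce_axes is a hypothetical illustration, not TVM API):

def remap_reduce_axes(old_layout, new_layout, axes):
    # 1) Record the names of the originally reduced axes, e.g. {"C"}.
    old_r_dims = {old_layout[a] for a in axes}

    # 2) Walk the new layout; only primal (uppercase) axes are counted, so
    #    the returned layout and axes are expressed in unblocked form.
    new_r_axes, new_layout_string, axis_index = [], "", 0
    for axis in new_layout:
        if axis in old_r_dims:
            new_r_axes.append(axis_index)
        if axis.isupper():                     # primal axis
            new_layout_string += axis
            axis_index += 1
    return new_layout_string, new_r_axes

# sum over C, written against NHWC (axis 3), with the producer now in
# NCHW16c: the op asks for plain NCHW and reduces axis 1 instead.
print(remap_reduce_axes(list("NHWC"), list("NCHW") + ["c"], [3]))
# -> ('NCHW', [1])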

src/relay/op/tensor/transform.cc

Lines changed: 19 additions & 7 deletions
@@ -283,22 +283,34 @@ Array<Array<Layout>> ConcatenateLayout(
     const Array<Layout>& new_in_layouts,
     const Array<Layout>& old_in_layouts,
     const Array<Array<IndexExpr>> &old_in_shapes) {
-  const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
+  ConcatenateAttrs* param = const_cast<ConcatenateAttrs*>(attrs.as<ConcatenateAttrs>());
 
   size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
                 static_cast<size_t>(param->axis);
 
   Layout ret;
+  bool is_new_layout_selected = false;
   if (new_in_layouts.defined()) {  // this function is called after some operators are alternated.
+    // If all the new input layouts are the same, that layout is selected, and the new index of
+    // the concatenation axis in it is identified. param->axis is then modified on the fly to
+    // conform to the new input layout.
     const auto& concate_dim = old_in_layouts[0][axis];
-    for (size_t i = 0; i < new_in_layouts.size(); ++i) {
-      if (new_in_layouts[i].ndim() > axis &&
-          new_in_layouts[i][axis] == concate_dim) {
-        ret = new_in_layouts[i];
-        break;
+    bool all_input_layouts_same = true;
+    for (auto new_layout : new_in_layouts) {
+      if (!new_layout.Equals(new_in_layouts[0])) {
+        all_input_layouts_same = false;
       }
     }
-  } else {  // this function is called on the original correct relay ir
+    if (all_input_layouts_same) {
+      auto new_index = new_in_layouts[0].IndexOf(concate_dim);
+      ret = new_in_layouts[0];
+      param->axis = new_index;
+      is_new_layout_selected = true;
+    }
+  }
+
+  if (!is_new_layout_selected) {
+    // this function is called on the original correct relay ir
     for (size_t i = 0; i < old_in_layouts.size(); ++i) {
       if (old_in_layouts[i].defined()) {
         ret = old_in_layouts[i];
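
The old loop only accepted a new input layout that kept the concatenation dimension at the same index, which fails once NHWC inputs arrive as NCHW16c (index 3 holds W, not C). The new rule: if every input agrees on one new layout, adopt it and relocate the axis by name. A pure-Python model (remap_concat_axis is a hypothetical illustration, not TVM API):

def remap_concat_axis(old_layout, new_layouts, axis):
    axis = axis if axis >= 0 else axis + len(old_layout)
    concat_dim = old_layout[axis]                 # e.g. "C" for NHWC, axis 3
    if any(nl != new_layouts[0] for nl in new_layouts):
        return None                               # layouts differ: fall back
    return new_layouts[0].index(concat_dim)       # the axis in the new layout

# Concatenate along C (NHWC axis 3); both inputs now arrive in NCHW16c.
print(remap_concat_axis(list("NHWC"), [list("NCHW") + ["c"]] * 2, 3))   # -> 1

# Inputs disagree on layout: keep the old fallback behavior.
print(remap_concat_axis(list("NHWC"),
                        [list("NCHW") + ["c"], list("NHWC")], 3))       # -> None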
