@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -36,6 +36,80 @@ namespace relay {
 // relay.nn.pad
 TVM_REGISTER_NODE_TYPE(PadAttrs);
 
+Array<Array<Layout> > PadInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>>& old_in_shapes) {
+  // NOTE: Discard the "const" qualifier here so that pad_width can be updated in place.
+  PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
+
+  Layout ret;
+  // If new_in_layouts is defined, this code tries to adopt the new layout.
+  bool is_layout_modified = new_in_layouts.defined();
+  if (new_in_layouts.defined()) {
+    // Build a map from axis to pad width, then generate the pad width for the
+    // new layout from that map. The new layout is rejected if padding occurs
+    // along an axis that was split.
+
+    // 1) Create a map from axis to pad width using the old layout.
+    std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+    int index_counter = 0;
+    CHECK_EQ(new_in_layouts.size(), 1);
+    for (auto iter_var : old_in_layouts[0]->axes) {
+      const auto& old_layout_axis = LayoutAxis::Get(iter_var);
+      axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
+      index_counter++;
+    }
+
+    // 2) Create the new pad width by walking over the new layout and consulting the map.
+    tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+    for (auto iter_var : new_in_layouts[0]->axes) {
+      const auto& new_layout_axis = LayoutAxis::Get(iter_var);
+      auto axis_name = new_layout_axis.name();
+      if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
+        // This is a primal axis, so use the original pad width directly.
+        new_pad_width.push_back(axis_pad_width.at(axis_name));
+      } else {
+        // This axis was split, so check that its pad width was originally [0, 0].
+        const auto& dual_axis = new_layout_axis.ToPrimal();
+        auto dual_axis_name = dual_axis.name();
+        CHECK(axis_pad_width.count(dual_axis_name));
+        new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
+
+        // If any pad width element is nonzero (or not a constant), do not change the layout.
+        for (auto width : axis_pad_width.at(dual_axis_name)) {
+          if (auto* width_imm = width.as<IntImm>()) {
+            if (width_imm->value != 0) {
+              is_layout_modified = false;
+            }
+          } else {
+            is_layout_modified = false;
+          }
+        }
+      }
+    }
+
+    // If the above conditions are satisfied, set the newly created pad_width
+    // and use the new layout.
+    if (is_layout_modified) {
+      ret = new_in_layouts[0];
+      params->pad_width = new_pad_width;
+    }
+  }
+
+  if (!is_layout_modified) {
+    if (old_in_layouts.defined()) {
+      CHECK_EQ(old_in_layouts.size(), 1);
+      ret = old_in_layouts[0];
+    } else {
+      ret = Layout::Undef();
+    }
+  }
+
+  return Array<Array<Layout> >{{ret}, {ret}};
+}
+
 bool PadRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
@@ -133,6 +207,7 @@ RELAY_REGISTER_OP("nn.pad")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("Pad", PadRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
 .set_attr<TOpPattern>("TOpPattern", kInjective)
 .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
 