 #include <vector>
 
 #include "../../pass/alter_op_layout.h"
+#include "convolution.h"
 
 namespace tvm {
 namespace relay {
 
 // relay.nn.conv2d
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
 
-bool Conv2DRel(const Array<Type>& types,
-               int num_inputs,
-               const Attrs& attrs,
-               const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 3);
-  const auto* data = types[0].as<TensorTypeNode>();
-  const auto* weight = types[1].as<TensorTypeNode>();
-  if (data == nullptr) return false;
-  static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
-
-  const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
-  CHECK(param != nullptr);
-  const Layout in_layout(param->data_layout);
-  const Layout kernel_layout(param->kernel_layout);
-
-  const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
-  CHECK(trans_in_layout.defined())
-    << "Conv only support input layouts that are convertible from NCHW."
-    << " But got " << in_layout;
-
-  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
-  CHECK(trans_kernel_layout.defined())
-    << "Conv only support kernel layouts that are convertible from OIHW."
-    << " But got "<< kernel_layout;
-
-  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
-  const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
-  CHECK(trans_out_layout.defined())
-    << "Conv only support output layouts that are convertible from NCHW."
-    << " But got " << out_layout;
-
-  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
-
-  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
-  // infer weight if the kernel_size and channels are defined
-  if (param->kernel_size.defined() && param->channels.defined()) {
-    CHECK_EQ(param->kernel_size.size(), 2);
-    CHECK_EQ(param->dilation.size(), 2);
-    Array<IndexExpr> wshape;
-
-    if (tvm::ir::Equal(param->channels, param->groups)) {
-      // infer weight's shape for depthwise convolution
-      wshape = {{dshape_nchw[1],
-                 param->groups / dshape_nchw[1],
-                 param->kernel_size[0],
-                 param->kernel_size[1]}};
-    } else {
-      wshape = {{param->channels,
-                 dshape_nchw[1] / param->groups,
-                 param->kernel_size[0],
-                 param->kernel_size[1]}};
-    }
-
-    wshape = trans_kernel_layout.BackwardShape(wshape);
-    channels = param->channels;
-    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
-    DataType weight_dtype = data->dtype;
-    if (weight != nullptr) {
-      weight_dtype = weight->dtype;
-    }
-    // assign result to reporter
-    reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
-  } else {
-    // use weight to infer the conv shape.
-    if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
-    if (param->kernel_size.defined()) {
-      CHECK_EQ(param->kernel_size.size(), 2);
-      // check the size
-      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
-            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
-        << "Conv2D: shape of weight is inconsistent with kernel_size, "
-        << " kernel_size=" << param->kernel_size
-        << " wshape=" << wshape;
-    }
-    if (param->channels.defined()) {
-      CHECK(reporter->AssertEQ(param->channels, wshape[0]))
-        << "Conv2D: shape of weight is inconsistent with channels, "
-        << " channels=" << param->channels
-        << " wshape=" << wshape;
-    }
-    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1]));
-    channels = wshape[0];
-    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
-    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
-  }
-  // dilation
-  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-
-  oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1);
-  oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1);
-  DataType out_dtype = param->out_dtype;
-  if (out_dtype.bits() == 0) {
-    out_dtype = data->dtype;
-  }
-  oshape = trans_out_layout.BackwardShape(oshape);
-  // assign output type
-  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
-  return true;
-}
-
 template<typename T>
 Array<Array<Layout> > Conv2DInferCorrectLayout(
     const Attrs& attrs,
@@ -208,7 +104,7 @@ with the layer input to produce a tensor of outputs.
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
-.add_type_rel("Conv2D", Conv2DRel)
+.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
 
 
@@ -770,7 +666,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(10)
-.add_type_rel("Conv2D", Conv2DRel)
+.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                Conv2DInferCorrectLayout<Conv2DAttrs>);
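
Note on the refactor: the Conv2DRel relation removed above is not deleted outright. Both registrations now call Conv2DRel<Conv2DAttrs>, so a templated version evidently moves into the newly included convolution.h, with the attrs type lifted to a template parameter that other conv2d-style operators (such as the NCHWc depthwise variant in the second hunk) can instantiate. The following is a minimal sketch of the shape that header plausibly takes; the include guard, include paths, and the elided body are assumptions, not contents of this commit — only the template signature is implied by the call sites.

// convolution.h -- hypothetical sketch, not the actual file from this commit.
#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
#define TVM_RELAY_OP_NN_CONVOLUTION_H_

#include <tvm/relay/attrs/nn.h>   // assumed include; declares Conv2DAttrs
#include <tvm/relay/type.h>       // assumed include; declares TypeReporter

namespace tvm {
namespace relay {

// Same signature as the removed Conv2DRel, but the attrs type is now a
// template parameter so one relation serves several conv2d operator variants.
template <typename AttrType>
bool Conv2DRel(const Array<Type>& types,
               int num_inputs,
               const Attrs& attrs,
               const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  // Downcast to the instantiating attrs type instead of hard-coded Conv2DAttrs.
  const AttrType* param = attrs.as<AttrType>();
  CHECK(param != nullptr);
  // ... remainder identical to the removed body above (layout conversion,
  // weight shape inference, output shape computation) ...
  return true;
}

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_OP_NN_CONVOLUTION_H_

Each operator then picks its attrs type at registration time, as the two hunks above do with .add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>).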
|