2727#include < tvm/relay/attrs/nn.h>
2828#include < vector>
2929
30+ #include " nn.h"
3031#include " ../../pass/alter_op_layout.h"
3132
3233namespace tvm {
@@ -35,99 +36,6 @@ namespace relay {
3536// relay.nn.conv2d
3637TVM_REGISTER_NODE_TYPE (Conv2DAttrs);
3738
// Type relation for nn.conv2d: infers the weight and/or output tensor types
// from the data type and the Conv2DAttrs.
// types = {data, weight, output}; returns false (defer inference) when the
// required input types are not yet known.
// NOTE(review): this diff deletes the function — it is replaced by the shared
// template Conv2DRel<Conv2DAttrs> (see the added #include "nn.h" and the
// updated .add_type_rel calls below); presumably the template body is this
// same logic moved to nn.h — confirm against that header.
38- bool Conv2DRel (const Array<Type>& types,
39- int num_inputs,
40- const Attrs& attrs,
41- const TypeReporter& reporter) {
// Exactly three type slots participate: data, weight, output.
42- CHECK_EQ (types.size (), 3 );
43- const auto * data = types[0 ].as <TensorTypeNode>();
44- const auto * weight = types[1 ].as <TensorTypeNode>();
// Without a concrete data type nothing can be inferred yet.
45- if (data == nullptr ) return false ;
// Canonical layouts all shape math below is performed in.
46- static const Layout kNCHW (" NCHW" );
47- static const Layout kOIHW (" OIHW" );
48- 
49- const Conv2DAttrs* param = attrs.as <Conv2DAttrs>();
50- CHECK (param != nullptr );
51- const Layout in_layout (param->data_layout );
52- const Layout kernel_layout (param->kernel_layout );
53- 
// User-specified layouts must be bijectively convertible to the canonical
// ones so shapes can be mapped back and forth.
54- const auto trans_in_layout = BijectiveLayoutNode::make (in_layout, kNCHW );
55- CHECK (trans_in_layout.defined ())
56- << " Conv only support input layouts that are convertible from NCHW."
57- << " But got " << in_layout;
58- 
59- const auto trans_kernel_layout = BijectiveLayoutNode::make (kernel_layout, kOIHW );
60- CHECK (trans_kernel_layout.defined ())
61- << " Conv only support kernel layouts that are convertible from OIHW."
62- << " But got " << kernel_layout;
63- 
// An empty out_layout means "same as data_layout".
64- Layout out_layout (param->out_layout == " " ? param->data_layout : param->out_layout );
65- const auto trans_out_layout = BijectiveLayoutNode::make (out_layout, kNCHW );
66- CHECK (trans_out_layout.defined ())
67- << " Conv only support output layouts that are convertible from NCHW."
68- << " But got " << out_layout;
69- 
// Data shape expressed in canonical NCHW order.
70- Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape (data->shape );
71- 
72- IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
73- // infer weight if the kernel_size and channels are defined
74- if (param->kernel_size .defined () && param->channels .defined ()) {
// Path 1: attrs fully specify the kernel — derive the weight shape
// (OIHW: out_channels, in_channels/groups, kh, kw) and assign it.
75- CHECK_EQ (param->kernel_size .size (), 2 );
76- CHECK_EQ (param->dilation .size (), 2 );
77- Array<IndexExpr> wshape (
78- {param->channels ,
79- dshape_nchw[1 ] / param->groups ,
80- param->kernel_size [0 ],
81- param->kernel_size [1 ]});
// Map the canonical OIHW shape back to the user's kernel_layout.
82- wshape = trans_kernel_layout.BackwardShape (wshape);
83- channels = param->channels ;
// Effective (dilated) kernel extents used for the output-size formula.
84- dilated_ksize_y = 1 + (param->kernel_size [0 ] - 1 ) * param->dilation [0 ];
85- dilated_ksize_x = 1 + (param->kernel_size [1 ] - 1 ) * param->dilation [1 ];
// Weight dtype defaults to the data dtype if the weight type is unknown.
86- DataType weight_dtype = data->dtype ;
87- if (weight != nullptr ) {
88- weight_dtype = weight->dtype ;
89- }
90- // assign result to reporter
91- reporter->Assign (types[1 ], TensorTypeNode::make (wshape, weight_dtype));
92- } else {
93- // use weight to infer the conv shape.
// Path 2: attrs are partial — read channels/kernel size off the weight
// tensor instead, cross-checking against whatever attrs do specify.
94- if (weight == nullptr ) return false ;
95- auto wshape = trans_kernel_layout.ForwardShape (weight->shape );
96- if (param->kernel_size .defined ()) {
97- CHECK_EQ (param->kernel_size .size (), 2 );
98- // check the size
99- CHECK (reporter->AssertEQ (param->kernel_size [0 ], wshape[2 ]) &&
100- reporter->AssertEQ (param->kernel_size [1 ], wshape[3 ]))
101- << " Conv2D: shape of weight is inconsistent with kernel_size, "
102- << " kernel_size=" << param->kernel_size
103- << " wshape=" << wshape;
104- }
105- if (param->channels .defined ()) {
106- CHECK (reporter->AssertEQ (param->channels , wshape[0 ]))
107- << " Conv2D: shape of weight is inconsistent with channels, "
108- << " channels=" << param->channels
109- << " wshape=" << wshape;
110- }
// Grouped conv: weight's in-channel dim must equal data channels / groups.
111- CHECK (reporter->AssertEQ (dshape_nchw[1 ] / param->groups , wshape[1 ]));
112- channels = wshape[0 ];
113- dilated_ksize_y = 1 + (wshape[2 ] - 1 ) * param->dilation [0 ];
114- dilated_ksize_x = 1 + (wshape[3 ] - 1 ) * param->dilation [1 ];
115- }
116- // dilation
// Output spatial size: floor((in + 2*pad - dilated_kernel) / stride) + 1.
// NOTE(review): padding is indexed [0]/[1] as symmetric (top==bottom,
// left==right) — asymmetric padding is not handled here.
117- Array<IndexExpr> oshape ({dshape_nchw[0 ], channels, 0 , 0 });
118- 
119- oshape.Set (2 , (dshape_nchw[2 ] + param->padding [0 ] * 2 - dilated_ksize_y) / param->strides [0 ] + 1 );
120- oshape.Set (3 , (dshape_nchw[3 ] + param->padding [1 ] * 2 - dilated_ksize_x) / param->strides [1 ] + 1 );
// out_dtype of 0 bits means "unspecified": fall back to the data dtype.
121- DataType out_dtype = param->out_dtype ;
122- if (out_dtype.bits () == 0 ) {
123- out_dtype = data->dtype ;
124- }
// Convert the canonical NCHW output shape back to the requested out_layout.
125- oshape = trans_out_layout.BackwardShape (oshape);
126- // assign output type
127- reporter->Assign (types[2 ], TensorTypeNode::make (oshape, out_dtype));
128- return true ;
129- }
130-
13139template <typename T>
13240Array<Array<Layout> > Conv2DInferCorrectLayout (
13341 const Attrs& attrs,
@@ -195,7 +103,7 @@ with the layer input to produce a tensor of outputs.
195103.add_argument(" data" , " Tensor" , " The input tensor." )
196104.add_argument(" weight" , " Tensor" , " The weight tensor." )
197105.set_support_level(2 )
198- .add_type_rel(" Conv2D" , Conv2DRel)
106+ .add_type_rel(" Conv2D" , Conv2DRel<Conv2DAttrs> )
199107.set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , Conv2DInferCorrectLayout<Conv2DAttrs>);
200108
201109
@@ -755,7 +663,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
755663.add_argument(" data" , " Tensor" , " The input tensor." )
756664.add_argument(" weight" , " Tensor" , " The weight tensor." )
757665.set_support_level(10 )
758- .add_type_rel(" Conv2D" , Conv2DRel)
666+ .add_type_rel(" Conv2D" , Conv2DRel<Conv2DAttrs> )
759667.set_attr<FInferCorrectLayout>(" FInferCorrectLayout" ,
760668 Conv2DInferCorrectLayout<Conv2DAttrs>);
761669
0 commit comments