@@ -926,7 +926,7 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
926926 if (data == nullptr ) return false ;
927927
928928 static const Layout kNCW (" NCW" );
929- static const Layout kOIW ( " OIW " );
929+ static const Layout kIOW ( " IOW " );
930930
931931 const Conv1DTransposeAttrs* param = attrs.as <Conv1DTransposeAttrs>();
932932 ICHECK (param != nullptr );
@@ -938,9 +938,9 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
938938 << " Conv only support input layouts that are convertible from NCW."
939939 << " But got " << in_layout;
940940
941- const auto trans_kernel_layout = tir::BijectiveLayout (kernel_layout, kOIW );
941+ const auto trans_kernel_layout = tir::BijectiveLayout (kernel_layout, kIOW );
942942 ICHECK (trans_kernel_layout.defined ())
943- << " Conv only support kernel layouts that are convertible from OIW ."
943+ << " Conv only support kernel layouts that are convertible from IOW ."
944944 << " But got " << kernel_layout;
945945
946946 Layout out_layout (param->out_layout == " " ? param->data_layout : param->out_layout );
@@ -979,16 +979,18 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
979979 ICHECK_EQ (param->kernel_size .size (), 1 );
980980 // check the size
981981 ICHECK (reporter->AssertEQ (param->kernel_size [0 ], wshape[2 ]))
982- << " Conv1D : shape of weight is inconsistent with kernel_size, "
982+ << " Conv1DTraspose : shape of weight is inconsistent with kernel_size, "
983983 << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
984984 }
985985 if (param->channels .defined ()) {
986- ICHECK (reporter->AssertEQ (param->channels , wshape[1 ]))
987- << " Conv1D: shape of weight is inconsistent with channels, "
988- << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
986+ ICHECK (reporter->AssertEQ (indexdiv (param->channels , param->groups ), wshape[1 ]))
987+ << " Conv1DTraspose: shape of weight is inconsistent with channels, "
988+ << " out_channels // groups != weight.shape[1] "
989+ << " out_channels=" << param->channels << " groups=" << param->groups
990+ << " wshape=" << Array<IndexExpr>(wshape);
989991 }
990992 if (!dshape_ncw[1 ].as <tir::AnyNode>() && !wshape[0 ].as <tir::AnyNode>()) {
991- ICHECK (reporter->AssertEQ (indexdiv ( dshape_ncw[1 ], param-> groups ) , wshape[0 ]));
993+ ICHECK (reporter->AssertEQ (dshape_ncw[1 ], wshape[0 ]));
992994 }
993995 channels = wshape[1 ];
994996 dilated_ksize_x = 1 + (wshape[2 ] - 1 ) * param->dilation [0 ];
0 commit comments