
Commit f5db8b7

[Bugfix] Conv1Dtranspose default kernel layout should be IOW (#14482)
* fix conv1Dtranspose kernel layout
* fix conv1Dtranspose type checker
* fix mxnet layout
1 parent f8f7bc8

6 files changed: +26 −17 lines

include/tvm/relay/attrs/nn.h

Lines changed: 3 additions & 3 deletions
@@ -671,10 +671,10 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
         "dimensions respectively. Convolution is applied on the"
         "'W' dimension.");
     TVM_ATTR_FIELD(kernel_layout)
-        .set_default("OIW")
+        .set_default("IOW")
         .describe(
-            "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
-            "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+            "Dimension ordering of data and weight. Can be 'IOW', 'IOW16o16i', etc."
+            "'I', 'O', 'W' stands for input_channel, num_filter and width"
             "dimensions respectively.");
     TVM_ATTR_FIELD(out_layout)
         .set_default("")

python/tvm/relay/frontend/keras.py

Lines changed: 8 additions & 4 deletions
@@ -282,6 +282,8 @@ def _convert_dense(
 
 
 def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):
+    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
+
     if input_shape is None:
         input_shape = keras_layer.input_shape
     _check_data_format(keras_layer)
@@ -290,19 +292,21 @@ def _convert_convolution1d(inexpr, keras_layer, etab, data_layout, input_shape=None):
 
     if data_layout == "NWC":
         kernel_layout = "WIO"
+        if is_deconv:
+            kernel_layout = "WOI"
     else:
         kernel_layout = "OIW"
+        if is_deconv:
+            kernel_layout = "IOW"
         msg = (
             "Kernel layout with {} is not supported for operator Convolution1D "
             "in frontend Keras."
         )
         raise tvm.error.OpAttributeUnImplemented(msg.format(data_layout))
 
-    is_deconv = type(keras_layer).__name__ == "Conv1DTranspose"
-
     if is_deconv:
-        if kernel_layout == "OIW":
-            weight = weight.transpose([2, 0, 1])
+        if kernel_layout == "IOW":
+            weight = weight.transpose([2, 1, 0])
         kernel_w, n_filters, _ = weight.shape
     else:
         kernel_w, _, n_filters = weight.shape
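
A hedged sketch of the weight permutation above, assuming Keras stores Conv1DTranspose kernels as (kernel_w, out_channels, in_channels), i.e. WOI; transpose([2, 1, 0]) then yields IOW for the NCW path:

import numpy as np

w_keras = np.zeros((3, 16, 8))          # (kernel_w, out_ch, in_ch) = "WOI"
w_relay = w_keras.transpose([2, 1, 0])  # -> (in_ch, out_ch, kernel_w) = "IOW"
assert w_relay.shape == (8, 16, 3)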

python/tvm/relay/frontend/mxnet.py

Lines changed: 1 addition & 1 deletion
@@ -304,7 +304,7 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid('Only "NCW" data layout is supported for 1D Convolution')
     channel_axis = 1
-    kernel_layout = "OIW"
+    kernel_layout = "IOW"
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
     new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")

python/tvm/relay/frontend/pytorch.py

Lines changed: 3 additions & 0 deletions
@@ -1263,6 +1263,9 @@ def convolution(self, inputs, input_types):
         else:
             data_layout = "NCW"
             kernel_layout = "OIW"
+            if use_transpose:
+                # Transposed convolutions have IOW layout.
+                kernel_layout = "IOW"
 
         # Conv1d does not currently support grouped convolution so we convert it to conv2d
         is_grouped_conv1d = False
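
This works without touching the weight data because torch.nn.ConvTranspose1d stores its weight as (in_channels, out_channels // groups, kernel_w), which already matches Relay's IOW; a minimal check (requires torch):

import torch

w = torch.nn.ConvTranspose1d(in_channels=8, out_channels=16, kernel_size=3).weight
assert tuple(w.shape) == (8, 16, 3)  # (in_ch, out_ch, kernel_w) = "IOW"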

python/tvm/relay/op/nn/nn.py

Lines changed: 1 addition & 1 deletion
@@ -604,7 +604,7 @@ def conv1d_transpose(
     channels=None,
     kernel_size=None,
     data_layout="NCW",
-    kernel_layout="OIW",
+    kernel_layout="IOW",
     out_layout="",
     output_padding=(0,),
     out_dtype="",

src/relay/op/nn/convolution.cc

Lines changed: 10 additions & 8 deletions
@@ -926,7 +926,7 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
   if (data == nullptr) return false;
 
   static const Layout kNCW("NCW");
-  static const Layout kOIW("OIW");
+  static const Layout kIOW("IOW");
 
   const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
   ICHECK(param != nullptr);
@@ -938,9 +938,9 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
       << "Conv only support input layouts that are convertible from NCW."
       << " But got " << in_layout;
 
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOW);
   ICHECK(trans_kernel_layout.defined())
-      << "Conv only support kernel layouts that are convertible from OIW."
+      << "Conv only support kernel layouts that are convertible from IOW."
      << " But got " << kernel_layout;
 
   Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
@@ -979,16 +979,18 @@ bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
     ICHECK_EQ(param->kernel_size.size(), 1);
     // check the size
     ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
-        << "Conv1D: shape of weight is inconsistent with kernel_size, "
+        << "Conv1DTranspose: shape of weight is inconsistent with kernel_size, "
        << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (param->channels.defined()) {
-    ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
-        << "Conv1D: shape of weight is inconsistent with channels, "
-        << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
+    ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
+        << "Conv1DTranspose: shape of weight is inconsistent with channels, "
+        << " out_channels // groups != weight.shape[1] "
+        << " out_channels=" << param->channels << " groups=" << param->groups
+        << " wshape=" << Array<IndexExpr>(wshape);
   }
   if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
-    ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
+    ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[0]));
   }
   channels = wshape[1];
   dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];