
Commit d430fbb

alexgl-github authored and vinx13 committed
Implement 1d deconvolution (#4476)
1 parent 9384353 commit d430fbb

File tree: 19 files changed, +853 -14 lines

include/tvm/relay/attrs/nn.h

Lines changed: 58 additions & 0 deletions
@@ -315,6 +315,64 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
   }
 };
 
+/*! \brief Attributes used in 1D transposed convolution operator */
+struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> output_padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") {
+    TVM_ATTR_FIELD(channels)
+        .set_default(NullValue<IndexExpr>())
+        .describe("The dimensionality of the output space"
+                  "i.e. the number of output channels in the convolution.");
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("The dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1}))
+        .describe("The strides of the convolution.");
+    TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0}))
+        .describe("Zero-padding added to one side of the output.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0}))
+        .describe("Symmetric or asymmetric padding."
+                  "Single value: the input is implicitly zero-padded on both sides."
+                  "Two values: padding[0] is used for left input padding, "
+                  "padding[1] is used for right input padding,");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(data_layout).set_default("NCW")
+        .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc."
+                  "'N', 'C', 'W' stands for batch, channel, and width"
+                  "dimensions respectively. Convolution is applied on the"
+                  "'W' dimension.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
+        .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc."
+                  "'O', 'I', 'W' stands for num_filter, input_channel, and width"
+                  "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc."
+                  "'N', 'C', 'W' stands for batch, channel, and width"
+                  "dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*! \brief Attributes for max pool operator */
 struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
   Array<IndexExpr> pool_size;
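
The new Conv1DTransposeAttrs mirrors Conv2DTransposeAttrs with a single spatial dimension. For reference only (not part of the diff), the output width these attributes imply for the dilation == 1 case supported by this commit follows the usual transposed-convolution relation; a small sketch with illustrative numbers:

def conv1d_transpose_out_width(in_width, kernel_size, stride=1, padding=0, output_padding=0):
    # Reference-only helper (not in this commit): output width of a 1D transposed
    # convolution with symmetric padding and dilation == 1.
    return (in_width - 1) * stride + kernel_size - 2 * padding + output_padding

# A 32-wide input with kernel 3, stride 2, padding 1, output_padding 1 gives width 64.
assert conv1d_transpose_out_width(32, 3, stride=2, padding=1, output_padding=1) == 64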

python/tvm/autotvm/task/relay_integration.py

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
         tvm.relay.op.nn.dense: [topi.nn.dense],
         tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
         tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
+        tvm.relay.op.nn.conv1d_transpose: [topi.nn.conv1d_transpose_ncw],
     }
 
     topi_funcs = []
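
With this mapping in place, AutoTVM task extraction can pick up conv1d_transpose workloads from a Relay program. A rough sketch (shapes are illustrative, and it assumes a tunable conv1d_transpose schedule template is registered for the chosen target, which the CUDA portion of this PR, not shown in this excerpt, is meant to provide):

from tvm import autotvm, relay

data = relay.var("data", shape=(1, 8, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 16, 3), dtype="float32")
net = relay.Function([data, weight],
                     relay.nn.conv1d_transpose(data, weight, channels=16,
                                               kernel_size=(3,), strides=(2,)))
tasks = autotvm.task.extract_from_program(net, params={},
                                          ops=(relay.op.nn.conv1d_transpose,),
                                          target="cuda")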

python/tvm/autotvm/task/topi_integration.py

Lines changed: 12 additions & 0 deletions
@@ -92,6 +92,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
             topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
             topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
+            topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
         }
 
         self.topi_to_schedule = {
@@ -109,6 +110,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
             topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
+            topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
         }
 
         # function reflection for tracing
@@ -125,6 +127,7 @@ def __init__(self, allow_duplicate=False):
             topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
             topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
             topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
+            topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
         }
 
         self.allow_duplicate = allow_duplicate
@@ -214,6 +217,15 @@ def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
             s = topi.generic.schedule_conv2d_transpose_nchw([C])
             return s, [A, W, C]
 
+        @register("topi_nn_conv1d_transpose_ncw")
+        def _topi_nn_conv1d_transpose_ncw(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            args = deserialize_args(args)
+            A, W = args[:2]
+            C = topi.nn.conv1d_transpose_ncw(*args, **kwargs)
+            s = topi.generic.schedule_conv1d_transpose_ncw([C])
+            return s, [A, W, C]
+
         @register("topi_nn_dense")
         def _topi_nn_dense(*args, **kwargs):
             assert not kwargs, "Do not support kwargs in template function call"

python/tvm/relay/_parser.py

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ def __call__(self, args, attrs, type_args):
     "nn.softmax": op.nn.softmax,
     "reshape": op.reshape,
     "nn.conv2d_transpose": op.nn.conv2d_transpose,
+    "nn.conv1d_transpose": op.nn.conv1d_transpose,
     "concatenate": op.concatenate,
     "nn.dropout": op.nn.dropout_raw,
     "zeros": op.zeros,

python/tvm/relay/frontend/mxnet.py

Lines changed: 7 additions & 13 deletions
@@ -207,29 +207,23 @@ def _mx_conv1d_transpose(inputs, attrs):
     if data_layout != "NCW":
         raise tvm.error.OpAttributeInvalid(
             'Only "NCW" data layout is supported for 1D Convolution')
-    data_layout = "NCHW"
     channel_axis = 1
-    kernel_layout = "OIHW"
-
+    kernel_layout = "OIW"
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
-    new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel")
-    new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
-    new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,))
-    new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
-    new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
+    new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
+    new_attrs["strides"] = attrs.get_int_tuple("stride", (1,))
+    new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0,))
+    new_attrs["padding"] = attrs.get_int_tuple("pad", (0,))
+    new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1,))
     new_attrs["groups"] = attrs.get_int("num_group", 1)
     new_attrs["data_layout"] = data_layout
     new_attrs["kernel_layout"] = kernel_layout
     use_bias = not attrs.get_bool("no_bias", True)
-    data = _op.expand_dims(inputs[0], axis=2)
-    kernel = _op.expand_dims(inputs[1], axis=2)
-    res = _op.nn.conv2d_transpose(data, kernel, **new_attrs)
-
+    res = _op.nn.conv1d_transpose(inputs[0], inputs[1], **new_attrs)
     if use_bias:
         assert len(inputs) == 3
         res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
-        res = _op.squeeze(res, axis=[2])
     return res
 
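The MXNet frontend now lowers a 1D Deconvolution straight to nn.conv1d_transpose instead of the old expand_dims / conv2d_transpose / squeeze round trip. A rough conversion sketch (requires MXNet; symbol names and shapes are illustrative only):

import mxnet as mx
from tvm import relay

data = mx.sym.Variable("data")
net = mx.sym.Deconvolution(data, kernel=(3,), stride=(2,), pad=(1,), adj=(1,),
                           num_filter=16, no_bias=True, name="deconv1d")
# The imported Relay program should now contain nn.conv1d_transpose rather than
# expand_dims + nn.conv2d_transpose + squeeze.
mod, params = relay.frontend.from_mxnet(net, shape={"data": (1, 8, 32)})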

python/tvm/relay/op/nn/_nn.py

Lines changed: 31 additions & 0 deletions
@@ -348,6 +348,37 @@ def legalize_conv2d_transpose(attrs, inputs, types):
 
 reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# conv1d_transpose
+@reg.register_compute("nn.conv1d_transpose")
+def compute_conv1d_transpose(attrs, inputs, out_dtype, target):
+    """Compute definition of conv1d_transpose"""
+    padding = get_const_tuple(attrs.padding)
+    strides = get_const_tuple(attrs.strides)
+    dilation = get_const_tuple(attrs.dilation)
+    groups = attrs.groups
+    layout = attrs.data_layout
+    out_dtype = attrs.out_dtype
+    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
+                 else out_dtype)
+    assert layout == "NCW", "conv1d_transpose ncw only supported"
+    assert dilation == (1,), "conv1d_transpose dilation is not supported"
+    assert groups == 1, "conv1d_transpose groups == 1 only supported"
+    out = topi.nn.conv1d_transpose_ncw(
+        inputs[0], inputs[1], strides, padding, out_dtype)
+    output_padding = get_const_tuple(attrs.output_padding)
+    out = topi.nn.pad(out,
+                      [0, 0, 0], [0, 0, output_padding[0]])
+    return [out]
+
+
+@reg.register_schedule("nn.conv1d_transpose")
+def schedule_conv1d_transpose(attrs, outs, target):
+    """Schedule definition of conv1d_transpose"""
+    with target:
+        return topi.generic.schedule_conv1d_transpose_ncw(outs)
+
+reg.register_pattern("nn.conv1d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)
+
 # bias_add
 reg.register_schedule("nn.bias_add", schedule_injective)
 reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)

python/tvm/relay/op/nn/nn.py

Lines changed: 66 additions & 0 deletions
@@ -257,6 +257,72 @@ def conv2d_transpose(data,
                                   kernel_layout, out_layout, output_padding, out_dtype)
 
 
+def conv1d_transpose(data,
+                     weight,
+                     strides=(1,),
+                     padding=(0,),
+                     dilation=(1,),
+                     groups=1,
+                     channels=None,
+                     kernel_size=None,
+                     data_layout="NCW",
+                     kernel_layout="OIW",
+                     out_layout="",
+                     output_padding=(0,),
+                     out_dtype=""):
+    """One dimensional transposed convolution operator.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    weight : tvm.relay.Expr
+        The weight expressions.
+
+    strides : Tuple[int], optional
+        The strides of convolution.
+
+    padding : Tuple[int], optional
+        The padding of convolution on both sides of inputs.
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    channels : int, optional
+        Number of output channels of this convolution.
+
+    kernel_size : tuple of int, optional
+        The spatial dimensions of the convolution kernel.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    kernel_layout : str, optional
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output; by default, out_layout is the same as data_layout.
+
+    output_padding : Tuple[int], optional
+        Additional zero-padding to be added to one side of the output.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv1d_transpose.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+    return _make.conv1d_transpose(data, weight, strides, padding, dilation,
+                                  groups, channels, kernel_size, data_layout,
+                                  kernel_layout, out_layout, output_padding, out_dtype)
+
+
 def softmax(data, axis=-1):
     r"""Computes softmax.
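
A minimal usage sketch of the new Relay call (assuming this commit is applied; the weight is assumed to follow the transposed-convolution convention of shape (in_channels, out_channels, kernel_width), and all numbers are illustrative):

from tvm import relay

data = relay.var("data", shape=(1, 8, 32), dtype="float32")    # NCW
weight = relay.var("weight", shape=(8, 16, 3), dtype="float32")
y = relay.nn.conv1d_transpose(data, weight,
                              channels=16, kernel_size=(3,),
                              strides=(2,), padding=(1,), output_padding=(1,))
# Expected output shape: (1, 16, 64), since
# out_width = (32 - 1) * 2 + 3 - 2 * 1 + 1 = 64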
