136 changes: 55 additions & 81 deletions topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -21,11 +21,11 @@
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
from ..util import equal_const_int, get_const_tuple, traverse_inline
from ..util import get_const_tuple, traverse_inline


@autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
"""Transposed 2D convolution nchw forward operator.

Parameters
@@ -48,67 +48,58 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
_, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
stride_h, stride_w = strides

# attach stride info to config, this is used in schedule space definition
cfg.stride = strides

# padding stage
fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
bpad_top = filter_h - 1 - fpad_top
bpad_bottom = filter_h - 1 - fpad_bottom
bpad_left = filter_w - 1 - fpad_left
bpad_right = filter_w - 1 - fpad_right

# padding stage
FirstPad = nn.pad(Input,
[0, 0, (bpad_top + stride_h - 1) // stride_h,
(bpad_left + stride_w - 1) // stride_w],
[0, 0, (bpad_bottom + stride_h - 1) // stride_h,
(bpad_right + stride_w - 1) // stride_w], name='FirstPad')

idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
# remove extra padding introduced by dilatation
border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)

# dilation stage
data = FirstPad
strides = [1, 1, stride_h, stride_w]
n = len(data.shape)

def _dilate(*indices):
not_zero = []
index_tuple = []
for i in range(n):
if not equal_const_int(strides[i], 1):
index_tuple.append(idxdiv(indices[i], strides[i]))
not_zero.append(idxmod(indices[i], strides[i]).equal(0))
else:
index_tuple.append(indices[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)

# convolution stage
out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
dc = tvm.reduce_axis((0, in_c), name='dc')
dh = tvm.reduce_axis((0, filter_h), name='dh')
dw = tvm.reduce_axis((0, filter_w), name='dw')

Output = tvm.compute(
(batch, out_c, out_h, out_w),
batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
_, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
stride_height, stride_width = stride
cfg.stride = stride
pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
padding, (kernel_height, kernel_width))

out_width = (inp_width - 1) * stride_width + \
kernel_width - pad_left - pad_right
pad_left = kernel_width - 1 - pad_left
pad_right = kernel_width - 1 - pad_right
dilated_width = stride_width * (inp_width - 1) + 1

out_height = (inp_height - 1) * stride_height + \
kernel_height - pad_top - pad_bottom
pad_top = kernel_height - 1 - pad_top
pad_bottom = kernel_height - 1 - pad_bottom
dilated_height = stride_height * (inp_height - 1) + 1

# compute pad
data = tvm.compute(
(batch, inp_channels,
pad_top + dilated_height + pad_bottom,
pad_left + dilated_width + pad_right),
lambda n, c, y, x: tvm.if_then_else(
tvm.all(x >= pad_left,
x < pad_left + dilated_width,
tvm.indexmod(x - pad_left, stride_width).equal(0),
y >= pad_top,
y < pad_top + dilated_height,
tvm.indexmod(y - pad_top, stride_height).equal(0)),
data[n, c,
tvm.indexdiv(y - pad_top, stride_height),
tvm.indexdiv(x - pad_left, stride_width)],
tvm.const(0., "float32")),
name='data_pad')

# compute transposed conv
dc = tvm.reduce_axis((0, inp_channels), name='dc')
dh = tvm.reduce_axis((0, kernel_height), name='dh')
dw = tvm.reduce_axis((0, kernel_width), name='dw')
data_out = tvm.compute(
(batch, out_channels, out_height, out_width),
lambda b, c, h, w: tvm.sum(
_dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
data[b, dc, h + dh, w + dw].astype(out_dtype) *
kernel[dc,
c,
kernel_height - 1 - dh,
kernel_width - 1 - dw].astype(out_dtype),
axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")

return Output
return data_out
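
Note on the rewritten compute: the old FirstPad/_dilate pipeline is folded into a single data_pad stage that scatters the NCHW input onto a stride-spaced grid and surrounds it with a (kernel - 1 - pad) zero border; the transposed convolution then reduces over the spatially flipped kernel with unit stride, so out = (in - 1) * stride + kernel - pad_lo - pad_hi in each spatial dimension. A minimal NumPy sketch of that padded/dilated layout (illustrative only, not part of the PR; the helper name is made up):

    import numpy as np

    def dilate_and_pad(data, kernel_hw, stride_hw, pad_tlbr):
        kh, kw = kernel_hw
        sh, sw = stride_hw
        pt, pl, pb, pr = pad_tlbr
        n, c, ih, iw = data.shape
        # zero border around the dilated grid, as in the data_pad compute above
        bpt, bpb = kh - 1 - pt, kh - 1 - pb
        bpl, bpr = kw - 1 - pl, kw - 1 - pr
        # dilated extents: one input value every `stride` positions
        dh, dw = sh * (ih - 1) + 1, sw * (iw - 1) + 1
        out = np.zeros((n, c, bpt + dh + bpb, bpl + dw + bpr), dtype=data.dtype)
        out[:, :, bpt:bpt + dh:sh, bpl:bpl + dw:sw] = data
        return out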

@autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
['cuda', 'gpu'], 'direct')
@@ -140,7 +131,8 @@ def _fallback_schedule(N, F, Y, X):
else:
cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
# split F (output channel dimension)
cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
if F > 1:
cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
# split Y (height dimension)
y_split_factor = 1
for candidate in range(5, 17):
@@ -185,26 +177,8 @@ def _callback(op):
cfg.define_knob("unroll_explicit", [0, 1])

if cfg.is_fallback:
ko = int(kernel.shape[1])
kh = int(kernel.shape[2])
kw = int(kernel.shape[3])
stride_h, stride_w = cfg.stride
# Workaround to make CUDA compilation work. Issue #4470
# TODO make _fallback_schedule work for all kernel/strides combinations
# after issue #4470 is resolved
do_fallback = True
if ko == 1:
do_fallback = False
elif (kh, kw) == (1, 1):
do_fallback = True
elif (stride_h, stride_w) == (2, 2):
do_fallback = False
elif (kh, kw) == (stride_h, stride_w):
do_fallback = False

if do_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)
N, F, Y, X = get_const_tuple(conv.shape)
_fallback_schedule(N, F, Y, X)

##### space definition end #####
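
Note on the fallback path: the issue #4470 workaround is dropped and _fallback_schedule now runs unconditionally on the conv shape, while inside _fallback_schedule the tile_f split is only filled in when F > 1 (with a single output channel the fixed factor of 64 cannot divide the axis). Roughly, the -1 leg of such a split is inferred from the axis extent; a small sketch of that resolution (illustrative only, not TVM's implementation):

    def resolve_split(extent, factors):
        # product of the fixed (non -1) factors
        fixed = 1
        for f in factors:
            if f != -1:
                fixed *= f
        # the -1 slot absorbs whatever remains
        return [extent // fixed if f == -1 else f for f in factors]

    print(resolve_split(128, [-1, 1, 64, 1]))  # -> [2, 1, 64, 1]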

29 changes: 19 additions & 10 deletions topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -25,10 +25,13 @@
from common import get_all_backend

def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
in_height, in_width = in_size
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
pad_top, pad_left, pad_bottom, pad_right = padding

A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
@@ -51,7 +54,10 @@ def check_device(device):
return
print("Running on target: %s" % device)
with tvm.target.create(device):
B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
B = topi.nn.conv2d_transpose_nchw(A, W,
[stride_height, stride_width],
[pad_top, pad_left, pad_bottom, pad_right],
A.dtype)
C = topi.nn.relu(B)
s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
@@ -66,18 +72,21 @@ def check_device(device):
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

for device in get_all_backend():
check_device(device)


def test_conv2d_transpose_nchw():
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)

verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1))
verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))

if __name__ == "__main__":
test_conv2d_transpose_nchw()
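
Note on the new test cases: with the per-dimension signature, the expected output spatial shape follows out = (in - 1) * stride + kernel - pad_lo - pad_hi in each dimension. A quick check for the asymmetric (8192, 1) case (illustrative only):

    def out_dim(in_size, kernel, stride, pad_lo, pad_hi):
        return (in_size - 1) * stride + kernel - pad_lo - pad_hi

    # verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
    print(out_dim(8192, 31, 2, 14, 15))  # height -> 16384
    print(out_dim(1, 1, 1, 0, 0))        # width  -> 1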