
Commit ac2209e

Alex Gladkov committed
Improve CUDA conv2d_transpose_nchw
- combine pad and dilate
- fix for the issue https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164
- fix for the issue apache#4472
1 parent e4d817d commit ac2209e

File tree
2 files changed: +74 −90 lines

topi/python/topi/cuda/conv2d_transpose_nchw.py

Lines changed: 55 additions & 80 deletions
@@ -25,7 +25,7 @@


 @autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
-def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
+def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.

     Parameters
@@ -48,67 +48,59 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
-    _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
-    stride_h, stride_w = strides
-
-    # attach stride info to config, this is used in schedule space definition
-    cfg.stride = strides
-
-    # padding stage
-    fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
-    bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
-    bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
-
-    # padding stage
-    FirstPad = nn.pad(Input,
-                      [0, 0, (bpad_top + stride_h - 1) // stride_h,
-                       (bpad_left + stride_w - 1) // stride_w],
-                      [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
-                       (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-    # remove extra padding introduced by dilatation
-    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
-    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
-
-    # dilation stage
-    data = FirstPad
-    strides = [1, 1, stride_h, stride_w]
-    n = len(data.shape)
-
-    def _dilate(*indices):
-        not_zero = []
-        index_tuple = []
-        for i in range(n):
-            if not equal_const_int(strides[i], 1):
-                index_tuple.append(idxdiv(indices[i], strides[i]))
-                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
-            else:
-                index_tuple.append(indices[i])
-        if not_zero:
-            not_zero = tvm.all(*not_zero)
-            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
-        return data(*index_tuple)
-
-    # convolution stage
-    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
-    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
-    dc = tvm.reduce_axis((0, in_c), name='dc')
-    dh = tvm.reduce_axis((0, filter_h), name='dh')
-    dw = tvm.reduce_axis((0, filter_w), name='dw')
-
-    Output = tvm.compute(
-        (batch, out_c, out_h, out_w),
+    batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+    _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+    stride_height, stride_width = stride
+    cfg.stride = stride
+    pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+        padding, (kernel_height, kernel_width))
+
+    out_width = (inp_width - 1) * stride_width + \
+                kernel_width - pad_left - pad_right
+    pad_left = kernel_width - 1 - pad_left
+    pad_right = kernel_width - 1 - pad_right
+    dilated_width = stride_width * (inp_width - 1) + 1
+
+    out_height = (inp_height - 1) * stride_height + \
+                 kernel_height - pad_top - pad_bottom
+    pad_top = kernel_height - 1 - pad_top
+    pad_bottom = kernel_height - 1 - pad_bottom
+    dilated_height = stride_height * (inp_height - 1) + 1
+
+    # compute pad
+    data = tvm.compute(
+        (batch, inp_channels,
+         pad_top + dilated_height + pad_bottom,
+         pad_left + dilated_width + pad_right),
+        lambda n, c, y, x: tvm.if_then_else(
+            tvm.all(x >= pad_left,
+                    x < pad_left + dilated_width,
+                    tvm.indexmod(x - pad_left, stride_width).equal(0),
+                    y >= pad_top,
+                    y < pad_top + dilated_height,
+                    tvm.indexmod(y - pad_top, stride_height).equal(0)
+                    ),
+            data[n, c,
+                 tvm.indexdiv(y - pad_top, stride_height),
+                 tvm.indexdiv(x - pad_left, stride_width)],
+            tvm.const(0., "float32")),
+        name='data_pad')
+
+    # compute transposed conv
+    dc = tvm.reduce_axis((0, inp_channels), name='dc')
+    dh = tvm.reduce_axis((0, kernel_height), name='dh')
+    dw = tvm.reduce_axis((0, kernel_width), name='dw')
+    data_out = tvm.compute(
+        (batch, out_channels, out_height, out_width),
         lambda b, c, h, w: tvm.sum(
-            _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
-            Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
+            data[b, dc, h + dh, w + dw].astype(out_dtype) *
+            kernel[dc,
+                   c,
+                   kernel_height - 1 - dh,
+                   kernel_width - 1 - dw].astype(out_dtype),
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")

-    return Output
+    return data_out


 @autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
                                      ['cuda', 'gpu'], 'direct')
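As a quick sanity check on the out_height/out_width arithmetic in the hunk above, the shape relation used is out = (in - 1) * stride + kernel - pad_before - pad_after; a back-of-the-envelope Python check (not part of the commit):

def conv2d_transpose_out_size(in_size, kernel, stride, pad_before, pad_after):
    # Same relation as out_height / out_width in the compute above.
    return (in_size - 1) * stride + kernel - pad_before - pad_after

# 32x32 input, 5x5 kernel, stride 2, padding 1 on each side -> 65x65 output
assert conv2d_transpose_out_size(32, 5, 2, 1, 1) == 65
# Height of the new asymmetric test case below (8192, kernel 31, stride 2, pads 14/15)
assert conv2d_transpose_out_size(8192, 31, 2, 14, 15) == 16384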
@@ -140,7 +132,8 @@ def _fallback_schedule(N, F, Y, X):
         else:
             cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
         # split F (output channel dimension)
-        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
+        if F > 1:
+            cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
         # split Y (height dimension)
         y_split_factor = 1
         for candidate in range(5, 17):
@@ -185,26 +178,8 @@ def _callback(op):
             cfg.define_knob("unroll_explicit", [0, 1])

             if cfg.is_fallback:
-                ko = int(kernel.shape[1])
-                kh = int(kernel.shape[2])
-                kw = int(kernel.shape[3])
-                stride_h, stride_w = cfg.stride
-                # Workaround to make CUDA compilation work. Issue #4470
-                # TODO make _fallback_schedule work for all kernel/strides combinations
-                # after issue #4470 is resolved
-                do_fallback = True
-                if ko == 1:
-                    do_fallback = False
-                elif (kh, kw) == (1, 1):
-                    do_fallback = True
-                elif (stride_h, stride_w) == (2, 2):
-                    do_fallback = False
-                elif (kh, kw) == (stride_h, stride_w):
-                    do_fallback = False
-
-                if do_fallback:
-                    N, F, Y, X = get_const_tuple(conv.shape)
-                    _fallback_schedule(N, F, Y, X)
+                N, F, Y, X = get_const_tuple(conv.shape)
+                _fallback_schedule(N, F, Y, X)

             ##### space definition end #####

topi/tests/python/test_topi_conv2d_transpose_nchw.py

Lines changed: 19 additions & 10 deletions
@@ -25,10 +25,13 @@
 from common import get_all_backend

 def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
-    in_height = in_width = in_size
+    in_height, in_width = in_size
+    kernel_height, kernel_width = kernel
+    stride_height, stride_width = stride
+    pad_top, pad_left, pad_bottom, pad_right = padding

     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
+    W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')

     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -51,7 +54,10 @@ def check_device(device):
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
+            B = topi.nn.conv2d_transpose_nchw(A, W,
+                                              [stride_height, stride_width],
+                                              [pad_top, pad_left, pad_bottom, pad_right],
+                                              A.dtype)
             C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
             s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
@@ -66,18 +72,21 @@ def check_device(device):
         func2(a, w, c)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-
     for device in get_all_backend():
         check_device(device)


 def test_conv2d_transpose_nchw():
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
-
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 1, (1, 1), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (3, 3), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (3, 3), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224), 32, (2, 2), (2, 2), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
+    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))

 if __name__ == "__main__":
     test_conv2d_transpose_nchw()
