2525
2626
2727@autotvm .task .register_topi_compute (nn .conv2d_transpose_nchw , ['cuda' , 'gpu' ], "direct" )
28- def conv2d_transpose_nchw_cuda (cfg , Input , Filter , strides , padding , out_dtype ):
28+ def conv2d_transpose_nchw_cuda (cfg , data , kernel , stride , padding , out_dtype ):
2929 """Transposed 2D convolution nchw forward operator.
3030
3131 Parameters
@@ -48,67 +48,59 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
4848 Output : tvm.Tensor
4949 4-D with shape [batch, out_channel, out_height, out_width]
5050 """
51- batch , in_c , in_h , in_w = get_const_tuple (Input .shape )
52- _ , out_c , filter_h , filter_w = get_const_tuple (Filter .shape )
53- stride_h , stride_w = strides
54-
55- # attach stride info to config, this is used in schedule space definition
56- cfg .stride = strides
57-
58- # padding stage
59- fpad_top , fpad_left , fpad_bottom , fpad_right = nn .get_pad_tuple (padding , (filter_h , filter_w ))
60- bpad_top = filter_h - 1 - fpad_top
61- bpad_bottom = filter_h - 1 - fpad_bottom
62- bpad_left = filter_w - 1 - fpad_left
63- bpad_right = filter_w - 1 - fpad_right
64-
65- # padding stage
66- FirstPad = nn .pad (Input ,
67- [0 , 0 , (bpad_top + stride_h - 1 ) // stride_h ,
68- (bpad_left + stride_w - 1 ) // stride_w ],
69- [0 , 0 , (bpad_bottom + stride_h - 1 ) // stride_h ,
70- (bpad_right + stride_w - 1 ) // stride_w ], name = 'FirstPad' )
71-
72- idxdiv = tvm .indexdiv
73- idxmod = tvm .indexmod
74- # remove extra padding introduced by dilatation
75- border_h = idxmod (stride_h - idxmod (bpad_top , stride_h ), stride_h )
76- border_w = idxmod (stride_w - idxmod (bpad_left , stride_w ), stride_w )
77-
78- # dilation stage
79- data = FirstPad
80- strides = [1 , 1 , stride_h , stride_w ]
81- n = len (data .shape )
82-
83- def _dilate (* indices ):
84- not_zero = []
85- index_tuple = []
86- for i in range (n ):
87- if not equal_const_int (strides [i ], 1 ):
88- index_tuple .append (idxdiv (indices [i ], strides [i ]))
89- not_zero .append (idxmod (indices [i ], strides [i ]).equal (0 ))
90- else :
91- index_tuple .append (indices [i ])
92- if not_zero :
93- not_zero = tvm .all (* not_zero )
94- return tvm .if_then_else (not_zero , data (* index_tuple ), tvm .const (0.0 , data .dtype ))
95- return data (* index_tuple )
96-
97- # convolution stage
98- out_h = (in_h - 1 ) * stride_h - fpad_top - fpad_bottom + filter_h
99- out_w = (in_w - 1 ) * stride_w - fpad_left - fpad_right + filter_w
100- dc = tvm .reduce_axis ((0 , in_c ), name = 'dc' )
101- dh = tvm .reduce_axis ((0 , filter_h ), name = 'dh' )
102- dw = tvm .reduce_axis ((0 , filter_w ), name = 'dw' )
103-
104- Output = tvm .compute (
105- (batch , out_c , out_h , out_w ),
51+ batch , inp_channels , inp_height , inp_width = get_const_tuple (data .shape )
52+ _ , out_channels , kernel_height , kernel_width = get_const_tuple (kernel .shape )
53+ stride_height , stride_width = stride
54+ cfg .stride = stride
55+ pad_top , pad_left , pad_bottom , pad_right = nn .get_pad_tuple (
56+ padding , (kernel_height , kernel_width ))
57+
58+ out_width = (inp_width - 1 ) * stride_width + \
59+ kernel_width - pad_left - pad_right
60+ pad_left = kernel_width - 1 - pad_left
61+ pad_right = kernel_width - 1 - pad_right
62+ dilated_width = stride_width * (inp_width - 1 ) + 1
63+
64+ out_height = (inp_height - 1 ) * stride_height + \
65+ kernel_height - pad_top - pad_bottom
66+ pad_top = kernel_height - 1 - pad_top
67+ pad_bottom = kernel_height - 1 - pad_bottom
68+ dilated_height = stride_height * (inp_height - 1 ) + 1
69+
70+ # compute pad
71+ data = tvm .compute (
72+ (batch , inp_channels ,
73+ pad_top + dilated_height + pad_bottom ,
74+ pad_left + dilated_width + pad_right ),
75+ lambda n , c , y , x : tvm .if_then_else (
76+ tvm .all (x >= pad_left ,
77+ x < pad_left + dilated_width ,
78+ tvm .indexmod (x - pad_left , stride_width ).equal (0 ),
79+ y >= pad_top ,
80+ y < pad_top + dilated_height ,
81+ tvm .indexmod (y - pad_top , stride_height ).equal (0 )
82+ ),
83+ data [n , c ,
84+ tvm .indexdiv (y - pad_top , stride_height ),
85+ tvm .indexdiv (x - pad_left , stride_width )],
86+ tvm .const (0. , "float32" )),
87+ name = 'data_pad' )
88+
89+ # compute transposed conv
90+ dc = tvm .reduce_axis ((0 , inp_channels ), name = 'dc' )
91+ dh = tvm .reduce_axis ((0 , kernel_height ), name = 'dh' )
92+ dw = tvm .reduce_axis ((0 , kernel_width ), name = 'dw' )
93+ data_out = tvm .compute (
94+ (batch , out_channels , out_height , out_width ),
10695 lambda b , c , h , w : tvm .sum (
107- _dilate (b , dc , h + dh + border_h , w + dw + border_w ).astype (out_dtype ) *
108- Filter [dc , c , filter_h - 1 - dh , filter_w - 1 - dw ].astype (out_dtype ),
96+ data [b , dc , h + dh , w + dw ].astype (out_dtype ) *
97+ kernel [dc ,
98+ c ,
99+ kernel_height - 1 - dh ,
100+ kernel_width - 1 - dw ].astype (out_dtype ),
109101 axis = [dc , dh , dw ]), tag = "conv2d_transpose_nchw" )
110102
111- return Output
103+ return data_out
112104
113105@autotvm .task .register_topi_schedule (generic .schedule_conv2d_transpose_nchw ,
114106 ['cuda' , 'gpu' ], 'direct' )
@@ -140,7 +132,8 @@ def _fallback_schedule(N, F, Y, X):
140132 else :
141133 cfg ["tile_n" ] = SplitEntity ([1 , 1 , 1 , 1 ])
142134 # split F (output channel dimension)
143- cfg ["tile_f" ] = SplitEntity ([- 1 , 1 , 64 , 1 ])
135+ if F > 1 :
136+ cfg ["tile_f" ] = SplitEntity ([- 1 , 1 , 64 , 1 ])
144137 # split Y (height dimension)
145138 y_split_factor = 1
146139 for candidate in range (5 , 17 ):
@@ -185,26 +178,8 @@ def _callback(op):
185178 cfg .define_knob ("unroll_explicit" , [0 , 1 ])
186179
187180 if cfg .is_fallback :
188- ko = int (kernel .shape [1 ])
189- kh = int (kernel .shape [2 ])
190- kw = int (kernel .shape [3 ])
191- stride_h , stride_w = cfg .stride
192- # Workaround to make CUDA compilation work. Issue #4470
193- # TODO make _fallback_schedule work for all kernel/strides combinations
194- # after issue #4470 is resolved
195- do_fallback = True
196- if ko == 1 :
197- do_fallback = False
198- elif (kh , kw ) == (1 , 1 ):
199- do_fallback = True
200- elif (stride_h , stride_w ) == (2 , 2 ):
201- do_fallback = False
202- elif (kh , kw ) == (stride_h , stride_w ):
203- do_fallback = False
204-
205- if do_fallback :
206- N , F , Y , X = get_const_tuple (conv .shape )
207- _fallback_schedule (N , F , Y , X )
181+ N , F , Y , X = get_const_tuple (conv .shape )
182+ _fallback_schedule (N , F , Y , X )
208183
209184 ##### space definition end #####
210185
0 commit comments