diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py index 63a91703bab3..4987f8d6fef2 100644 --- a/topi/python/topi/cuda/conv2d_nchw.py +++ b/topi/python/topi/cuda/conv2d_nchw.py @@ -10,19 +10,18 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): ofactor = 16 hfactor = 2 ow_size = util.get_const_int(Out.shape[3]) - num_thread = ow_size*hfactor - vthread = hfactor + num_thread = ow_size * hfactor + vthread = ofactor block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx") i, oc, h, w = s[Out].op.axis - ooc, ioc = s[Out].split(oc, factor=ofactor) + ooc, ioc = s[Out].split(oc, factor=vthread) oh, ih = s[Out].split(h, factor=hfactor) s[Out].reorder(ooc, oh, ioc, ih, w) oc = s[Out].fuse(ooc, oh) w = s[Out].fuse(w, ih) - s[Out].bind(w, thread_x) s[Out].bind(ioc, thread_xz) s[Out].bind(oc, block_x) @@ -261,9 +260,9 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): else: # scheduler params - vthread_x = util.get_const_int(Out.shape[2]) + vthread_x = min(8, util.get_const_int(Out.shape[2])) num_thread_x = 16 - num_thread_y = util.get_const_int(Out.shape[3]) + num_thread_y = min(8, util.get_const_int(Out.shape[3])) ofactor = 8 block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x") @@ -272,10 +271,12 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): i, oc, h, w = s[Out].op.axis ooc, ioc = s[Out].split(oc, factor=num_thread_x) - s[Out].reorder(i, ooc, h, w, ioc) + oh, ih = s[Out].split(h, factor=vthread_x) + ow, iw = s[Out].split(w, factor=num_thread_y) + s[Out].reorder(i, ooc, oh, ih, ow, iw, ioc) s[Out].bind(ioc, thread_x) - s[Out].bind(w, thread_y) - s[Out].bind(h, thread_xz) + s[Out].bind(iw, thread_y) + s[Out].bind(ih, thread_xz) s[Out].bind(ooc, block_x) s[Out_L].compute_at(s[Out], ioc) @@ -289,21 +290,19 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) - rfactor = util.get_const_int(Filter.shape[1]) - thread_xx = tvm.thread_axis((0, rfactor), "threadIdx.x") + num_thread = 512 + thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") - i, ic, h, w = s[temp].op.axis - ic = s[temp].fuse(ic, h, w) - oic, iic = s[temp].split(ic, factor=rfactor) - s[temp].bind(iic, thread_xx) - s[temp].bind(oic, block_xx) - - i, h, w, oic, iic = s[temp_R].op.axis - ic = s[temp_R].fuse(oic, iic) - s[temp_R].bind(ic, thread_xx) - h = s[temp_R].fuse(h, w) - s[temp_R].bind(h, block_xx) + i = s[temp].fuse(*s[temp].op.axis) + bx, tx = s[temp].split(i, factor=num_thread) + s[temp].bind(tx, thread_xx) + s[temp].bind(bx, block_xx) + + i = s[temp_R].fuse(*s[temp_R].op.axis) + bx, tx = s[temp_R].split(i, factor=num_thread) + s[temp_R].bind(tx, thread_xx) + s[temp_R].bind(bx, block_xx) #schedule temp_S shared mem load i, h, w, oc, ic = s[temp_S].op.axis