Commit: ec03729
Parent: f1d4337

[TOPI][CUDA] schedule for group_conv2d

File tree: 2 files changed, +119 -6 lines changed

  topi/python/topi/cuda/group_conv2d_nchw.py   (+118, -2)
  topi/tests/python/test_topi_group_conv2d.py  (+1, -4)

topi/python/topi/cuda/group_conv2d_nchw.py

Lines changed: 118 additions & 2 deletions
@@ -321,8 +321,124 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
     return s
 
 
+def schedule_group_conv2d_nchw_direct(cfg, s, conv):
+    """Schedule group conv2d NCHW direct template"""
+    workload = conv.op.attrs["workload"]
+    groups = get_const_int(workload[6])
+    num_filters = get_const_int(conv.shape[1])
+
+    ##### space definition begin #####
+    n, f, y, x = s[conv].op.axis
+    rc, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_n", n, num_outputs=4)
+    cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2)
+    cfg.define_split("tile_f", cfg.axis(num_filters // groups), num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.current_target()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    g, f = s[output].split(f, nparts=groups)
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
+    bg, vg = cfg["tile_g"].apply(s, output, g)
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    s[output].reorder(bn, bg, bf, by, bx, vn, vg, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bg, bf), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
+    s[output].bind(vg, tvm.thread_axis("vthread"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[OL].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
+
+    # tile reduction axes
+    n, f, y, x = s[OL].op.axis
+    rc, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
+        s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
+        s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+    N, CO, OH, OW = get_const_tuple(output.shape)
+    _, KH, KW, CI = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW // groups)
+
+
 @autotvm.register_topi_schedule(generic.schedule_group_conv2d_nchw,
-                                ["cuda", "gpu"], ["int8"])
+                                ["cuda", "gpu"], ["int8", "direct"])
 def schedule_conv2d_nchw_cuda(cfg, outs):
     """TOPI schedule callback of group conv2d for cuda gpu

@@ -347,7 +463,7 @@ def _callback(op):
         if op.tag == "group_conv2d_NCHWc_int8":
             schedule_group_conv2d_NCHWc_int8(cfg, s, op.output(0))
         if op.tag == "group_conv2d_nchw":
-            raise tvm.error.OpNotImplemented("group_conv2d_nchw not supported")
+            schedule_group_conv2d_nchw_direct(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
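For context (not part of the commit): the sketch below shows one way the new "direct" template gets exercised from user code, mirroring the device loop in the test updated below. It builds a float32 group conv2d through the generic TOPI entry points under a CUDA target; the shapes are arbitrary, and the topi.nn.group_conv2d_nchw signature is assumed to match the one used by topi/tests/python/test_topi_group_conv2d.py.

    # Illustrative sketch only, not from this commit. Assumes a CUDA-enabled TVM
    # build of this era and the topi.nn.group_conv2d_nchw signature used by the
    # TOPI tests (data, kernel, stride, padding, dilation, groups, out_dtype).
    import numpy as np
    import tvm
    import topi
    from topi.util import get_const_tuple

    batch, in_channel, in_size = 1, 32, 56
    num_filter, kernel, stride, padding, groups = 64, 3, 1, 1, 4

    A = tvm.placeholder((batch, in_channel, in_size, in_size), name='A')
    W = tvm.placeholder((num_filter, in_channel // groups, kernel, kernel), name='W')

    with tvm.target.create("cuda"):
        # dispatches to the CUDA group conv2d compute and, via the registration
        # changed in this diff, to schedule_group_conv2d_nchw_direct
        C = topi.nn.group_conv2d_nchw(A, W, stride, padding, dilation=1,
                                      groups=groups, out_dtype='float32')
        s = topi.generic.schedule_group_conv2d_nchw([C])
        func = tvm.build(s, [A, W, C], "cuda")

    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=get_const_tuple(A.shape)).astype('float32'), ctx)
    w = tvm.nd.array(np.random.uniform(size=get_const_tuple(W.shape)).astype('float32'), ctx)
    c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype='float32'), ctx)
    func(a, w, c)

Without tuned AutoTVM logs this falls back to the template's default configuration, so it compiles and runs but is not representative of tuned performance.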

topi/tests/python/test_topi_group_conv2d.py

Lines changed: 1 addition & 4 deletions

@@ -67,9 +67,6 @@ def check_device(device):
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
-        if device == "cuda" and not tvm.contrib.nvcc.have_int8(ctx.compute_version):
-            print("Skip because int8 intrinsics are not available")
-            return
 
         print("Running on target: %s" % device)
         with tvm.target.create(device):

@@ -94,7 +91,7 @@ def check_device(device):
         func(a, w, c)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ["llvm"]:
+    for device in ["llvm", "cuda"]:
         check_device(device)
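A quick, hypothetical way to exercise the updated test on a machine with a CUDA-enabled TVM build is to run the module's test functions in-process. This driver is not part of the commit and assumes the file keeps the usual TOPI test layout with argument-free top-level test_* functions.

    # Hypothetical driver, not part of the commit. Run from the repository root;
    # assumes the test module exposes argument-free test_* functions at top level.
    import importlib
    import sys

    sys.path.insert(0, "topi/tests/python")
    mod = importlib.import_module("test_topi_group_conv2d")
    for name in sorted(dir(mod)):
        if name.startswith("test_"):
            print("running", name)
            getattr(mod, name)()  # check_device now includes "cuda" in its device list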