From c019f9cc6e9d519b751c9fb0ff3be539bf2d52b6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 23 Oct 2017 08:16:45 +0900 Subject: [PATCH 1/3] update topi/cuda schedules to use target.max_num_threads --- python/tvm/target.py | 8 ++++++++ topi/python/topi/cuda/conv2d_nchw.py | 11 +++++++---- topi/python/topi/cuda/depthwise_conv2d.py | 8 ++++++-- topi/python/topi/cuda/injective.py | 4 +--- topi/python/topi/cuda/pooling.py | 2 +- topi/python/topi/cuda/reduction.py | 7 ++++++- topi/python/topi/generic/nn.py | 17 +++++++++++++++++ topi/tests/python/test_topi_depthwise_conv2d.py | 9 +++++---- topi/tests/python/test_topi_relu.py | 3 ++- 9 files changed, 53 insertions(+), 16 deletions(-) diff --git a/python/tvm/target.py b/python/tvm/target.py index 1bcd1de7d3d9..898f8b2d845d 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -311,3 +311,11 @@ def current_target(allow_none=True): "Requires a current target in generic function, but it is not set. " "Please set it using `with TargetObject:`") return Target.current + + +def get_max_num_threads(): + """Returns the maximum number of threads under current target. + """ + target = current_target() + target = target if target else cuda() + return target.max_num_threads diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py index 34cc61b1d78c..391bf7d0aad8 100644 --- a/topi/python/topi/cuda/conv2d_nchw.py +++ b/topi/python/topi/cuda/conv2d_nchw.py @@ -36,7 +36,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], ic) s[Filter_S].compute_at(s[Out_L], w) - num_thread1 = 512 + num_thread1 = tvm.target.get_max_num_threads() thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") @@ -116,7 +116,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], dw) - num_thread = 512 + num_thread = tvm.target.get_max_num_threads() thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") @@ -186,15 +186,18 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): ic, dh, dw = s[Out_L].op.reduce_axis oic, iic = s[Out_L].split(ic, factor=ifactor) s[Out_L].reorder(oic, dh, dw, iic, h, w) + max_num_thread = tvm.target.get_max_num_threads() if util.get_const_int(Filter_S.shape[1]) == 128: oic = s[Out_L].fuse(dh, oic) s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) - num_thread = 512 + num_thread = max_num_thread else: s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], dw) num_thread = 456 + if max_num_thread < num_thread: + num_thread = max_num_thread thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") @@ -300,7 +303,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) - num_thread = 512 + num_thread = tvm.target.get_max_num_threads() thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index 935d8a98b811..a83838d003ec 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -119,6 +119,7 @@ def traverse(OP): traverse(outs[0].op) return s 
+@generic.schedule_depthwise_conv2d_nhwc.register(["cuda", "gpu"])
 def schedule_depthwise_conv2d_nhwc(outs):
     """Schedule for depthwise_conv2d nhwc forward.
@@ -151,8 +152,11 @@ def _schedule(temp, Filter, DepthwiseConv2d):
         b, h, w, c = s[Output].op.axis

-        ic_val = tvm.ir_pass.Simplify(temp.shape[3]).value
-        xoc, xic = s[Output].split(c, factor=ic_val)
+        num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
+        max_num_thread = tvm.target.get_max_num_threads()
+        if max_num_thread < num_thread:
+            num_thread = max_num_thread
+        xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
         xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2)
         fused = s[Output].fuse(yo, xo)
diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py
index 940f91e5938c..039744045fca 100644
--- a/topi/python/topi/cuda/injective.py
+++ b/topi/python/topi/cuda/injective.py
@@ -6,9 +6,7 @@ def _schedule_injective(op, sch):
     x = op.output(0)
     fused = sch[x].fuse(*sch[x].op.axis)
-    target = tvm.target.current_target()
-    target = target if target else tvm.target.cuda()
-    num_thread = target.max_num_threads
+    num_thread = tvm.target.get_max_num_threads()
     bx, tx = sch[x].split(fused, factor=num_thread)
     sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
     sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py
index fc1850d31221..9b4d7face10a 100644
--- a/topi/python/topi/cuda/pooling.py
+++ b/topi/python/topi/cuda/pooling.py
@@ -84,7 +84,7 @@ def schedule_pool(outs):
     s = tvm.create_schedule([x.op for x in outs])
     def _schedule(PaddedInput, Pool):
         s[PaddedInput].compute_inline()
-        num_thread = 512
+        num_thread = tvm.target.get_max_num_threads()
         if Pool.op in s.outputs:
             Out = Pool
             OL = s.cache_write(Pool, "local")
diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py
index 02250fadb90d..8bc7b4837fd4 100644
--- a/topi/python/topi/cuda/reduction.py
+++ b/topi/python/topi/cuda/reduction.py
@@ -16,12 +16,17 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
     if len(sch[data_out].op.axis) > 0:
         all_reduce = False
         num_thread = 32
+        target = tvm.target.current_target()
+        if target and target.target_name == "opencl":
+            # Without this, CL_INVALID_WORK_GROUP_SIZE occurred when running
+            # test_topi_reduce.py; the root cause is not yet understood.
+            num_thread = 16
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
         thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
     else:
         all_reduce = True
-        num_thread = 512
+        num_thread = tvm.target.get_max_num_threads()
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")

     # Fuse and refactor the reduce axis
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 41f8117077ca..3a335790d7ed 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -53,6 +53,23 @@ def schedule_depthwise_conv2d_nchw(outs):
     return _default_schedule(outs, False)


+@tvm.target.generic_func
+def schedule_depthwise_conv2d_nhwc(outs):
+    """Schedule for depthwise_conv2d nhwc
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of depthwise_conv2d nhwc
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 @tvm.target.generic_func
 def schedule_reduce(outs):
     """Schedule for reduction
diff --git a/topi/tests/python/test_topi_depthwise_conv2d.py b/topi/tests/python/test_topi_depthwise_conv2d.py
index dc507d1bfc3d..adaad2221fcd 100644
--- a/topi/tests/python/test_topi_depthwise_conv2d.py
+++ b/topi/tests/python/test_topi_depthwise_conv2d.py
@@ -32,7 +32,6 @@ def check_device(device):
             s1 = topi.generic.schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
             s2 = topi.generic.schedule_depthwise_conv2d_nchw(ScaleShift)
             s3 = topi.generic.schedule_depthwise_conv2d_nchw(Relu)
-        ctx = tvm.context(device, 0)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
@@ -107,14 +106,16 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
     ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
     Relu = topi.nn.relu(ScaleShift)
     # schedule
-    s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
-    s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
-    s3 = schedule_depthwise_conv2d_nhwc(Relu)

     def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
+
+        with tvm.target.create(device):
+            s1 = topi.generic.schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
+            s2 = topi.generic.schedule_depthwise_conv2d_nhwc(ScaleShift)
+            s3 = topi.generic.schedule_depthwise_conv2d_nhwc(Relu)
         ctx = tvm.context(device, 0)
         # build the kernels
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py
index 3307100043d1..04ee3d87b0dc 100644
--- a/topi/tests/python/test_topi_relu.py
+++ b/topi/tests/python/test_topi_relu.py
@@ -8,7 +8,6 @@ def verify_relu(m, n):
     A = tvm.placeholder((m, n), name='A')
     B = topi.nn.relu(A)
-    s = topi.cuda.schedule_elemwise(B)

     a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
     b_np = a_np * (a_np > 0)

@@ -17,6 +16,8 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
+        with tvm.target.create(device):
+            s = topi.generic.schedule_elemwise(B)
         ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)

From 7f553f499be0d99eac51e77d4a9298a5bc221915 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 23 Oct 2017 08:27:23 +0900
Subject: [PATCH 2/3] allow num_thread to be larger than cuda.max_num_threads

---
 topi/python/topi/cuda/depthwise_conv2d.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py
index a83838d003ec..851a00db0a48 100644
--- a/topi/python/topi/cuda/depthwise_conv2d.py
+++ b/topi/python/topi/cuda/depthwise_conv2d.py
@@ -152,10 +152,11 @@ def _schedule(temp, Filter, DepthwiseConv2d):
         b, h, w, c = s[Output].op.axis

+        # num_thread here can be 728, which is larger than cuda.max_num_threads
         num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value
-        max_num_thread = tvm.target.get_max_num_threads()
-        if max_num_thread < num_thread:
-            num_thread = max_num_thread
+        target = tvm.target.current_target()
+        if target and target.target_name != "cuda":
+            num_thread = target.max_num_threads
         xoc, xic = s[Output].split(c, factor=num_thread)
         s[Output].reorder(xoc, b, h, w, xic)
         xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2)
         fused = s[Output].fuse(yo, xo)

From 92b896292de113aedea5259ad1123af3900d7d72 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda Date: Mon, 23 Oct 2017 09:02:07 +0900 Subject: [PATCH 3/3] remove get_max_num_threads and make it inline --- python/tvm/target.py | 8 -------- topi/python/topi/cuda/conv2d_nchw.py | 8 ++++---- topi/python/topi/cuda/injective.py | 2 +- topi/python/topi/cuda/pooling.py | 2 +- topi/python/topi/cuda/reduction.py | 2 +- 5 files changed, 7 insertions(+), 15 deletions(-) diff --git a/python/tvm/target.py b/python/tvm/target.py index 898f8b2d845d..1bcd1de7d3d9 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -311,11 +311,3 @@ def current_target(allow_none=True): "Requires a current target in generic function, but it is not set. " "Please set it using `with TargetObject:`") return Target.current - - -def get_max_num_threads(): - """Returns the maximum number of threads under current target. - """ - target = current_target() - target = target if target else cuda() - return target.max_num_threads diff --git a/topi/python/topi/cuda/conv2d_nchw.py b/topi/python/topi/cuda/conv2d_nchw.py index 391bf7d0aad8..ccc6eb254b09 100644 --- a/topi/python/topi/cuda/conv2d_nchw.py +++ b/topi/python/topi/cuda/conv2d_nchw.py @@ -36,7 +36,7 @@ def conv2d_224_3_64(s, temp, temp_R, temp_S, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], ic) s[Filter_S].compute_at(s[Out_L], w) - num_thread1 = tvm.target.get_max_num_threads() + num_thread1 = tvm.target.current_target(allow_none=False).max_num_threads thread_xx = tvm.thread_axis((0, num_thread1), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") @@ -116,7 +116,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], dw) - num_thread = tvm.target.get_max_num_threads() + num_thread = tvm.target.current_target(allow_none=False).max_num_threads thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") @@ -186,7 +186,7 @@ def conv2d_56_64_128(s, temp, temp_R, temp_S, Filter_S, Out, Out_L, flag): ic, dh, dw = s[Out_L].op.reduce_axis oic, iic = s[Out_L].split(ic, factor=ifactor) s[Out_L].reorder(oic, dh, dw, iic, h, w) - max_num_thread = tvm.target.get_max_num_threads() + max_num_thread = tvm.target.current_target(allow_none=False).max_num_threads if util.get_const_int(Filter_S.shape[1]) == 128: oic = s[Out_L].fuse(dh, oic) s[temp_S].compute_at(s[Out_L], oic) @@ -303,7 +303,7 @@ def conv2d_14_256_256(s, temp, temp_R, temp_S, Filter, Filter_S, Out, Out_L): s[temp_S].compute_at(s[Out_L], oic) s[Filter_S].compute_at(s[Out_L], oic) - num_thread = tvm.target.get_max_num_threads() + num_thread = tvm.target.current_target(allow_none=False).max_num_threads thread_xx = tvm.thread_axis((0, num_thread), "threadIdx.x") block_xx = tvm.thread_axis("blockIdx.x") diff --git a/topi/python/topi/cuda/injective.py b/topi/python/topi/cuda/injective.py index 039744045fca..0143aec36a7b 100644 --- a/topi/python/topi/cuda/injective.py +++ b/topi/python/topi/cuda/injective.py @@ -6,7 +6,7 @@ def _schedule_injective(op, sch): x = op.output(0) fused = sch[x].fuse(*sch[x].op.axis) - num_thread = tvm.target.get_max_num_threads() + num_thread = tvm.target.current_target(allow_none=False).max_num_threads bx, tx = sch[x].split(fused, factor=num_thread) sch[x].bind(bx, tvm.thread_axis("blockIdx.x")) sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py index 9b4d7face10a..4ed5ae66c19b 100644 --- a/topi/python/topi/cuda/pooling.py +++ 
b/topi/python/topi/cuda/pooling.py @@ -84,7 +84,7 @@ def schedule_pool(outs): s = tvm.create_schedule([x.op for x in outs]) def _schedule(PaddedInput, Pool): s[PaddedInput].compute_inline() - num_thread = tvm.target.get_max_num_threads() + num_thread = tvm.target.current_target(allow_none=False).max_num_threads if Pool.op in s.outputs: Out = Pool OL = s.cache_write(Pool, "local") diff --git a/topi/python/topi/cuda/reduction.py b/topi/python/topi/cuda/reduction.py index 8bc7b4837fd4..932f2aae3098 100644 --- a/topi/python/topi/cuda/reduction.py +++ b/topi/python/topi/cuda/reduction.py @@ -26,7 +26,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False): thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y") else: all_reduce = True - num_thread = tvm.target.get_max_num_threads() + num_thread = tvm.target.current_target(allow_none=False).max_num_threads thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x") # Fuse and refactor the reduce axis
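A minimal, self-contained sketch (not part of the patch series) of the pattern the series converges on after PATCH 3/3: a schedule reads the thread limit from the active target via tvm.target.current_target(allow_none=False).max_num_threads, so it must be built under "with tvm.target.create(device):", exactly as the updated tests do. The helper name schedule_flat and the toy compute are illustrative only; all API calls are the 2017-era ones used throughout these patches.

    import tvm

    def schedule_flat(out):
        """Bind a fused elementwise op using the current target's thread limit."""
        s = tvm.create_schedule([out.op])
        fused = s[out].fuse(*s[out].op.axis)
        # Same inline pattern the final patch leaves in injective.py, pooling.py
        # and reduction.py: raise if no target is in scope instead of silently
        # assuming CUDA's limit.
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        bx, tx = s[out].split(fused, factor=num_thread)
        s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[out].bind(tx, tvm.thread_axis("threadIdx.x"))
        return s

    A = tvm.placeholder((64, 64), name="A")
    B = tvm.compute(A.shape, lambda i, j: A[i, j] * A[i, j], name="B")

    # As in the updated tests, the schedule is created inside a target context,
    # so max_num_threads matches the selected device (cuda, opencl, metal, ...).
    with tvm.target.create("cuda"):
        s = schedule_flat(B)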