88 changes: 88 additions & 0 deletions tests/python/relay/test_op_grad_level2.py
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import topi
import topi.testing
from tvm import relay
from tvm.relay.transform import gradient
from tvm.relay.testing import ctx_list
from tvm.relay.testing import run_infer_type


def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
x = relay.var("x", relay.TensorType(x_shape, "float32"))
y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
ceil_mode=ceil_mode)

fwd_func = relay.Function([x], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))

data = np.random.rand(*x_shape).astype("float32")
ph, pw = padding
y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
out_grad = np.ones(shape=y_shape)
ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
padding=[ph, pw, ph, pw],
pool_type='max', ceil_mode=ceil_mode)

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_max_pool2d_grad():
verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
ceil_mode=False)
verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)


def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
x = relay.var("x", relay.TensorType(x_shape, "float32"))
y = tvm.relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
ceil_mode=ceil_mode, count_include_pad=count_include_pad)

fwd_func = relay.Function([x], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))

data = np.random.rand(*x_shape).astype("float32")
ph, pw = padding
y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
out_grad = np.ones(shape=y_shape)
ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
padding=[ph, pw, ph, pw],
pool_type='avg', ceil_mode=ceil_mode)

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
ceil_mode=False, count_include_pad=True)
verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
ceil_mode=False, count_include_pad=False)


if __name__ == "__main__":
test_max_pool2d_grad()
test_avg_pool2d_grad()
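
# The tests above run the gradient through relay.create_executor for every enabled
# target returned by ctx_list(). The following is a minimal standalone sketch, not
# part of the committed test file, assuming a CUDA-enabled TVM build; it evaluates
# the same max_pool2d gradient on the GPU, which should pick up the pool_grad CUDA
# schedule added in topi/python/topi/cuda/pooling.py below.
#
# import numpy as np
# import tvm
# from tvm import relay
# from tvm.relay.transform import gradient
# from tvm.relay.testing import run_infer_type
#
# # Build the forward max_pool2d function and transform it into its gradient.
# x = relay.var("x", relay.TensorType((1, 4, 16, 16), "float32"))
# y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
# bwd_func = run_infer_type(gradient(run_infer_type(relay.Function([x], y))))
#
# # Evaluate on the GPU; the result is (forward output, (gradient w.r.t. x,)).
# intrp = relay.create_executor(ctx=tvm.gpu(0), target="cuda")
# data = np.random.rand(1, 4, 16, 16).astype("float32")
# fwd_out, (grad_x,) = intrp.evaluate(bwd_func)(data)
# print(grad_x.asnumpy().shape)  # (1, 4, 16, 16), same shape as the input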
50 changes: 50 additions & 0 deletions topi/python/topi/cuda/pooling.py
@@ -19,6 +19,7 @@
import tvm
from .. import tag
from .. import generic
from ..util import traverse_inline



@@ -150,3 +151,52 @@ def traverse(OP):

traverse(outs[0].op)
return s


@generic.schedule_pool_grad.register(['cuda', 'gpu'])
def schedule_pool_grad_cuda(outs):
"""Schedule for pool_grad on CUDA

Parameters
----------
outs: Array of Tensor
The computation graph description of pool_grad
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for pool_grad.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])

def _schedule_pool_grad(op):
if op in s.outputs:
out = op
else:
out = outs[0].op.output(0)
fused = s[out].fuse(*s[out].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[out].split(fused, factor=num_thread)
s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
s[out].bind(tx, tvm.thread_axis("threadIdx.x"))

if tag.COMM_REDUCE_IDX in op.input_tensors[0].op.tag:
max_pool_index = op.input_tensors[0]
s[max_pool_index].compute_at(s[out], tx)

pool_input = max_pool_index.op.input_tensors[0]
if isinstance(pool_input.op, tvm.tensor.ComputeOp):
# handle padding
s[pool_input].compute_inline()
if op not in s.outputs:
s[op].compute_at(s[out], tx)

def _callback(op):
if op.tag.startswith('pool_grad'):
_schedule_pool_grad(op)

traverse_inline(s, outs[0].op, _callback)

return s
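
# For reference, a minimal sketch of how the schedule registered above would be
# picked up through the generic dispatcher topi.generic.schedule_pool_grad. The
# shapes and pooling parameters are illustrative only, and a CUDA-enabled build
# is assumed; this is not part of the committed change.
#
# import numpy as np
# import tvm
# import topi
#
# # Declare a max pool_grad computation in NCHW layout, mirroring the TOPI tests below.
# A = tvm.placeholder((1, 4, 16, 16), name='A')
# OutGrad = tvm.placeholder((1, 4, 8, 8), name='OutGrad')
# PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[2, 2], stride=[2, 2],
#                              padding=[0, 0, 0, 0], pool_type='max',
#                              ceil_mode=False, layout='NCHW')
#
# # Inside a CUDA target context, the generic dispatcher resolves to
# # schedule_pool_grad_cuda defined above.
# with tvm.target.cuda():
#     s = topi.generic.schedule_pool_grad([PoolGrad])
#
# f = tvm.build(s, [A, OutGrad, PoolGrad], 'cuda')
# ctx = tvm.gpu(0)
# a = tvm.nd.array(np.random.rand(1, 4, 16, 16).astype('float32'), ctx)
# out_grad = tvm.nd.array(np.ones((1, 4, 8, 8), dtype='float32'), ctx)
# pool_grad = tvm.nd.array(np.zeros((1, 4, 16, 16), dtype='float32'), ctx)
# f(a, out_grad, pool_grad)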
14 changes: 12 additions & 2 deletions topi/tests/python/test_topi_pooling.py
@@ -86,7 +86,8 @@ def check_device(device):
for device in get_all_backend():
check_device(device)

def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True,
add_relu=False):
iw = ih
kw = kh
sw = sh
@@ -110,13 +111,17 @@ def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_inc
PoolGrad = topi.nn.pool_grad(OutGrad, A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
layout="NCHW", count_include_pad=count_include_pad)
if add_relu:
PoolGrad = topi.nn.relu(PoolGrad)

a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
out_grad_np = np.random.uniform(low=0.001, size=bshape).astype(dtype)
pool_grad_np = topi.testing.pool_grad_nchw(a_np, out_grad_np, pool_size=(kh, kw),
strides=(sh, sw), padding=padding,
pool_type=pool_type, ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
if add_relu:
pool_grad_np = np.maximum(pool_grad_np, 0.)

def check_device(device):
ctx = tvm.context(device, 0)
@@ -134,7 +139,7 @@ def check_device(device):
f(a, out_grad, pool_grad)
tvm.testing.assert_allclose(pool_grad.asnumpy(), pool_grad_np, rtol=1e-5)

for device in ['llvm']: # only support llvm
for device in get_all_backend():
check_device(device)

def test_pool():
@@ -152,6 +157,7 @@ def test_pool():
verify_pool(1, 256, 31, 3, 3, [1, 0, 3, 2], 'max', False)
verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True)

def test_pool_grad():
verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False)
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True)
verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True)
@@ -169,6 +175,9 @@ def test_pool():
verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'max', False)
verify_pool_grad(1, 256, 32, 1, 2, [1, 1, 1, 1], 'avg', False, False)

verify_pool_grad(1, 256, 31, 4, 4, [0, 0, 0, 0], 'avg', False, False, add_relu=True)
verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True)


def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
@@ -258,5 +267,6 @@ def test_adaptive_pool():

if __name__ == "__main__":
test_pool()
test_pool_grad()
test_global_pool()
test_adaptive_pool()