103 changes: 80 additions & 23 deletions python/tvm/topi/adreno/reduction.py
@@ -25,45 +25,102 @@
 
 
 def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
-    if is_idx_reduce:
-        real_output = op.output(0)
+    sch_output = sch.outputs[0].output(0)
+    use_rfactor = False
+    if not is_idx_reduce:
+        rdomain = 1
+        whole_rop_output = op.output(0)
+        for axis in sch[whole_rop_output].op.reduce_axis:
+            rdomain = rdomain * axis.dom.extent
+        if rdomain > 50:
+            use_rfactor = True
+            # shared gives better perf, but works only for the rfactor flow
+            scope = "shared"
+        else:
+            # in case of direct scheduling, shared fails to compile
+            scope = "local"
+        if op in sch.outputs:
+            whole_rop_output = sch.cache_write(sch_output, scope)
+        else:
+            # no change to the whole_rop_output definition, but the proper scope must be set
+            sch[whole_rop_output].set_scope(scope)
+    else:
         temp_idx_input = op.input_tensors[0].op.output(0)
         temp_val_input = op.input_tensors[0].op.output(1)
-    else:
-        real_output = op.output(0)
-    shape = get_const_tuple(real_output.shape)
+        sch[temp_idx_input].set_scope("local")
+        sch[temp_val_input].set_scope("local")
+
+    shape = get_const_tuple(sch_output.shape)
     latest4 = shape[-1] == 4
     div4 = numpy.prod(shape) % 4 == 0
 
     # Fuse and split the axis
     if latest4:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in range(len(sch[real_output].op.axis) - 1)]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in range(len(sch[sch_output].op.axis) - 1)]
         )
     else:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in range(len(sch[real_output].op.axis))]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in range(len(sch[sch_output].op.axis))]
        )
 
     ftc = numpy.prod(shape)
     a = fused_outer
-    if latest4:
-        sch[real_output].vectorize(sch[real_output].op.axis[-1])
-    elif div4 and not is_idx_reduce:
-        a, b = sch[real_output].split(fused_outer, factor=4)
-        sch[real_output].vectorize(b)
-        ftc = ftc / 4
 
-    num_thread = get_div(ftc, 128)
+    if not is_idx_reduce:
+        if use_rfactor:
+            # the values below were selected empirically, assuming that each thread should do
+            # some work (currently 25-49 reduction elements) and that the number of threads
+            # should not exceed 256, a threshold chosen after experiments on Adreno 660
+            max_threads = rdomain.value // 25 if rdomain > 25 else 1
+            max_threads = 256 if max_threads > 256 else max_threads
+            num_thread = get_div(rdomain, max_threads)
 
-    bx, outer_in = sch[real_output].split(a, factor=num_thread)
+            fused_reduce = sch[whole_rop_output].fuse(*sch[whole_rop_output].op.reduce_axis)
+            thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+            _, ki = sch[whole_rop_output].split(fused_reduce, factor=num_thread)
+            data_out_rf = sch.rfactor(whole_rop_output, ki)
+            sch[data_out_rf].compute_at(
+                sch[whole_rop_output], sch[whole_rop_output].op.reduce_axis[0]
+            )
+            sch[whole_rop_output].bind(sch[whole_rop_output].op.reduce_axis[0], thread_y)
 
-    sch[real_output].bind(bx, te.thread_axis("blockIdx.x"))
-    sch[real_output].bind(outer_in, te.thread_axis("threadIdx.y"))
-    if is_idx_reduce:
-        sch[temp_idx_input].compute_at(sch[real_output], outer_in)
-        sch[temp_val_input].compute_at(sch[real_output], outer_in)
+    if div4:
+        if latest4:
+            b = sch[sch_output].op.axis[-1]
+        else:
+            a, b = sch[sch_output].split(fused_outer, factor=4)
+        sch[sch_output].vectorize(b)
+        if not use_rfactor:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], b)
+                sch[temp_val_input].compute_at(sch[sch_output], b)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], b)
+
+    if not use_rfactor:
+        num_thread = get_div(ftc, 128)
+        bx, outer_in = sch[sch_output].split(a, factor=num_thread)
+        sch[sch_output].bind(bx, te.thread_axis("blockIdx.x"))
+        sch[sch_output].bind(outer_in, te.thread_axis("threadIdx.x"))
+
+        if not div4:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], outer_in)
+                sch[temp_val_input].compute_at(sch[sch_output], outer_in)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], outer_in)
+    else:
+        sch[sch_output].bind(a, te.thread_axis("blockIdx.x"))
+        if not div4 or use_rfactor:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], a)
+                sch[temp_val_input].compute_at(sch[sch_output], a)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], a)
 
 
 def schedule_reduce(outs):
-    return schedule_reduce_impl(outs, _schedule_reduce_adreno, schedule_injective_from_existing)
+    return schedule_reduce_impl(
+        outs, _schedule_reduce_adreno, schedule_injective_from_existing, True
+    )
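The empirical comment in the rfactor branch boils down to a small arithmetic rule: give each thread at least ~25 reduction elements, never use more than 256 threads, then snap to a divisor of the reduction domain so the reduce axis splits evenly. A standalone sketch of that rule (not the TVM code; it assumes get_div returns the largest divisor of its first argument that does not exceed the second):

def pick_num_threads(rdomain):
    # at least ~25 reduction elements per thread, at most 256 threads
    max_threads = rdomain // 25 if rdomain > 25 else 1
    max_threads = min(max_threads, 256)
    # snap to a divisor of rdomain so the reduce axis splits evenly (the role of get_div)
    return max(d for d in range(1, max_threads + 1) if rdomain % d == 0)

# e.g. a 160x160 global pooling reduces rdomain = 25600 elements: max_threads = 256,
# 256 divides 25600, so each thread accumulates 100 elements before the cross-thread step
print(pick_num_threads(160 * 160))  # 256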
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/reduction.py
@@ -116,7 +116,9 @@ def is_scheduled(stage):
     return True
 
 
-def schedule_reduce_impl(outs, schedule_reduce_stage, schedule_injective_stage):
+def schedule_reduce_impl(
+    outs, schedule_reduce_stage, schedule_injective_stage, inline_postops=False
+):
     """Schedule for inject->reduce->bcast ops.
     Traverse over the stages in the schedule and schedule separate stages depending
     on the position of the stage. Injective post-ops of reduction will be scheduled using
@@ -160,7 +162,7 @@ def traverse_before_reduce(operator):
     def traverse_after_reduce(operator):
         """Internal traverse function"""
         if tag.is_broadcast(operator.tag):
-            if operator not in scheduled_ops:
+            if operator not in scheduled_ops and not inline_postops:
                 schedule_injective_stage(sch, operator.output(0))
             for tensor in operator.input_tensors:
                 if tensor.op not in scheduled_ops:
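The new inline_postops flag defaults to False, so existing GPU callers keep giving broadcast post-ops their own injective schedule; only the Adreno schedule_reduce above opts in, leaving post-ops to _schedule_reduce_adreno so they land in the same kernel as the reduction. A sketch of the two call styles (the CUDA internal name _schedule_reduce is assumed here):

# existing callers: unchanged, broadcast post-ops get a separate injective schedule
sch = schedule_reduce_impl(outs, _schedule_reduce, schedule_injective_from_existing)

# Adreno (this PR): post-ops are skipped here and handled by the reduce schedule itself
sch = schedule_reduce_impl(
    outs, _schedule_reduce_adreno, schedule_injective_from_existing, True
)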
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_texture_storage.cc
@@ -404,7 +404,7 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
     } else if (const OpNode* opnode = call->op.as<OpNode>()) {
       auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
       auto pattern = fpattern[GetRef<Op>(opnode)];
-      if (pattern <= kInjective) {
+      if (pattern <= kCommReduce) {
         if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
           if (ttype->shape.size() == 5) {
             supports_texture_storage = true;
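Relaxing the check from kInjective to kCommReduce lets reduction ops on 5-D (NCHW4c-style) tensors use texture storage as well, since TOpPattern values are ordered kElemWise(0) < kBroadcast(1) < kInjective(2) < kCommReduce(3). A quick way to inspect the registered pattern of a reduce op from Python (a sketch; the printed value of 3 is an assumption based on how reduce ops are registered):

from tvm import relay

# reduce ops such as mean are registered with pattern kCommReduce, which the relaxed
# `pattern <= kCommReduce` check now admits for texture storage annotation
print(relay.op.get("mean").get_attr("TOpPattern"))  # expected: 3 (kCommReduce)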
126 changes: 126 additions & 0 deletions tests/python/relay/opencl_texture/test_reduction_texture.py
@@ -51,5 +51,131 @@ def test_argmax(remote, target, dtype):
     build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
 
 
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_reduction_max(remote, target, dtype):
+    # NCHW
+    input_shape = (1, 3, 720, 1280)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    argmax = relay.op.max(A, axis=[1])
+    mod = relay.Function([A], argmax)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_mean_nd4(remote, target, dtype):
+    # NCHW
+    input_shape = (1, 3, 729, 729)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    mean = relay.mean(A, axis=1, keepdims=True)
+    mod = relay.Function([A], mean)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_argmax_nd4(remote, target, dtype):
+    # NCHW
+    input_shape = (1, 3, 729, 729)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    argmax = relay.op.argmax(A, axis=[1])
+    mod = relay.Function([A], argmax)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_reduction_max_nd4(remote, target, dtype):
+    # NCHW
+    input_shape = (1, 3, 729, 729)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    argmax = relay.op.max(A, axis=[1])
+    mod = relay.Function([A], argmax)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_mean_b4(remote, target, dtype):
+    # NCHW4c
+    input_shape = (1, 3, 720, 320, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    mean = relay.mean(A, axis=1, keepdims=True)
+    mod = relay.Function([A], mean)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_argmax_b4(remote, target, dtype):
+    # NCHW4c
+    input_shape = (1, 3, 720, 320, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    argmax = relay.op.argmax(A, axis=[1])
+    mod = relay.Function([A], argmax)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_reduction_max_b4(remote, target, dtype):
+    # NCHW4c
+    input_shape = (1, 3, 720, 320, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    argmax = relay.op.max(A, axis=[1])
+    mod = relay.Function([A], argmax)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_mean_global_pooling(remote, target, dtype):
+    """
+    Use case of global pooling with big spatial values
+    """
+    input_shape = (1, 160, 160, 32)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    mean = relay.mean(A, axis=[1, 2], keepdims=True)
+    mod = relay.Function([A], mean)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_mean_global_pooling_block4(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global pooling with big spatial values
+    """
+    input_shape = (1, 160, 160, 8, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    mean = relay.mean(A, axis=[1, 2], keepdims=True)
+    mod = relay.Function([A], mean)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_max_global_pooling_block4(remote, target, dtype):
+    """
+    Use case of blocked NCHW4c global pooling with big spatial values
+    """
+    input_shape = (1, 160, 160, 8, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    mean = relay.max(A, axis=[1, 2], keepdims=True)
+    mod = relay.Function([A], mean)
+
+    build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)
+
+
 if __name__ == "__main__":
     tvm.testing.main()