Skip to content

Commit aa0b202

Browse files
committed
[TOPI] improve inclusive_scan for thrust
Fix comments
1 parent e56c5e1 commit aa0b202

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

python/tvm/topi/cuda/scan.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
3535
return tvmop_to_thrust_func_name[tvmop]
3636

3737

38+
def _can_use_scan_thrust(binop):
39+
"""
40+
Check if scan_thrust can be utilized based on the current target and binary op.
41+
"""
42+
target = tvm.target.Target.current()
43+
if target is None:
44+
return False
45+
return binop == tvm.tir.generic.add and any(
46+
[
47+
can_use_thrust(target, "tvm.contrib.thrust.sum_scan"),
48+
can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan"),
49+
]
50+
)
51+
52+
3853
def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
3954
"""Low level IR to do exclusive sum scan along rows of 2D input.
4055
@@ -363,17 +378,9 @@ def exclusive_scan(
363378
"""
364379

365380
def do_scan(data, output_dtype):
366-
target = tvm.target.Target.current()
367381

368382
# TODO: add support for a prod_scan
369-
if (
370-
target
371-
and binop == tvm.tir.generic.add
372-
and (
373-
can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
374-
or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
375-
)
376-
):
383+
if _can_use_scan_thrust(binop):
377384
return scan_thrust(
378385
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
379386
)
@@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
479486
output : tvm.te.Tensor
480487
A N-D tensor of the same rank N as the input data.
481488
"""
489+
490+
if _can_use_scan_thrust(binop):
491+
if output_dtype is None or output_dtype == "":
492+
output_dtype = data.dtype
493+
ndim = len(data.shape)
494+
if axis < 0:
495+
axis += ndim
496+
497+
if axis != ndim - 1:
498+
axes = swap(list(range(ndim)), axis)
499+
data = transpose(data, axes)
500+
output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
501+
if axis != ndim - 1:
502+
axes = swap(list(range(ndim)), axis)
503+
output = transpose(output, axes)
504+
return output
505+
482506
ex_scan = exclusive_scan(
483507
data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
484508
)

0 commit comments

Comments
 (0)