Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
return tvmop_to_thrust_func_name[tvmop]


def _can_use_scan_thrust(binop):
"""
Check if scan_thrust can be utilized based on the current target and binary op.
"""
target = tvm.target.Target.current()
if target is None:
return False
return binop == tvm.tir.generic.add and any(
[
can_use_thrust(target, "tvm.contrib.thrust.sum_scan"),
can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan"),
]
)


def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0):
"""Low level IR to do exclusive sum scan along rows of 2D input.

Expand Down Expand Up @@ -363,17 +378,9 @@ def exclusive_scan(
"""

def do_scan(data, output_dtype):
target = tvm.target.Target.current()

# TODO: add support for a prod_scan
if (
target
and binop == tvm.tir.generic.add
and (
can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
)
):
if _can_use_scan_thrust(binop):
return scan_thrust(
data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
)
Expand Down Expand Up @@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
output : tvm.te.Tensor
A N-D tensor of the same rank N as the input data.
"""

if _can_use_scan_thrust(binop):
if output_dtype is None or output_dtype == "":
output_dtype = data.dtype
ndim = len(data.shape)
if axis < 0:
axis += ndim

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)
output = scan_thrust(data, output_dtype, exclusive=False, binop=binop)
if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
output = transpose(output, axes)
return output

ex_scan = exclusive_scan(
data, axis, output_dtype=output_dtype, binop=binop, identity_value=identity_value
)
Expand Down