@@ -35,6 +35,21 @@ def _get_thrust_func_name(tvmop):
3535 return tvmop_to_thrust_func_name [tvmop ]
3636
3737
38+ def _can_use_scan_thrust (binop ):
39+ """
40+ Check if scan_thrust can be utilized based on the current target and binary op.
41+ """
42+ target = tvm .target .Target .current ()
43+ if target is None :
44+ return False
45+ return binop == tvm .tir .generic .add and any (
46+ [
47+ can_use_thrust (target , "tvm.contrib.thrust.sum_scan" ),
48+ can_use_rocthrust (target , "tvm.contrib.thrust.sum_scan" ),
49+ ]
50+ )
51+
52+
3853def exclusive_scan_ir (data , output , reduction = None , binop = tvm .tir .generic .add , identity_value = 0 ):
3954 """Low level IR to do exclusive sum scan along rows of 2D input.
4055
@@ -363,17 +378,9 @@ def exclusive_scan(
363378 """
364379
365380 def do_scan (data , output_dtype ):
366- target = tvm .target .Target .current ()
367381
368382 # TODO: add support for a prod_scan
369- if (
370- target
371- and binop == tvm .tir .generic .add
372- and (
373- can_use_thrust (target , "tvm.contrib.thrust.sum_scan" )
374- or can_use_rocthrust (target , "tvm.contrib.thrust.sum_scan" )
375- )
376- ):
383+ if _can_use_scan_thrust (binop ):
377384 return scan_thrust (
378385 data , output_dtype , exclusive = True , return_reduction = return_reduction , binop = binop
379386 )
@@ -479,6 +486,23 @@ def inclusive_scan(data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add,
479486 output : tvm.te.Tensor
480487 A N-D tensor of the same rank N as the input data.
481488 """
489+
490+ if _can_use_scan_thrust (binop ):
491+ if output_dtype is None or output_dtype == "" :
492+ output_dtype = data .dtype
493+ ndim = len (data .shape )
494+ if axis < 0 :
495+ axis += ndim
496+
497+ if axis != ndim - 1 :
498+ axes = swap (list (range (ndim )), axis )
499+ data = transpose (data , axes )
500+ output = scan_thrust (data , output_dtype , exclusive = False , binop = binop )
501+ if axis != ndim - 1 :
502+ axes = swap (list (range (ndim )), axis )
503+ output = transpose (output , axes )
504+ return output
505+
482506 ex_scan = exclusive_scan (
483507 data , axis , output_dtype = output_dtype , binop = binop , identity_value = identity_value
484508 )
0 commit comments