Skip to content

Commit caf6b03

Browse files
authored
[TVMScript][Parser] Add more warp-level builtins and Range (#14279)
# Motivation Several builtins ("tvm_storage_sync", "tvm_warp_shuffle", "tvm_warp_shuffle_up", "tvm_warp_shuffle_down", "tvm_warp_activemask") and `Range` can appear in the output of the TVMScript printer but are missing from the TVMScript parser. This PR fixes that by adding them to the parser.
1 parent e3c8f2b commit caf6b03

File tree

3 files changed

+201
-13
lines changed

3 files changed

+201
-13
lines changed

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
import numpy as np # type: ignore
3030

3131
from tvm import tir
32-
from tvm.ir import Range, Type
32+
from tvm import ir
33+
from tvm.ir import Type
3334
from tvm.ir.base import deprecated
3435
from tvm.runtime import String, convert, ndarray
3536
from tvm.target import Target
@@ -496,7 +497,7 @@ def alloc_buffer(
496497
)
497498

498499

499-
def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
500+
def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
500501
"""The range constructor.
501502
502503
Parameters
@@ -509,21 +510,21 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
509510
res : Range
510511
The Range.
511512
"""
512-
if isinstance(dom, Range):
513+
if isinstance(dom, ir.Range):
513514
return dom
514515
if isinstance(dom, (list, tuple)):
515-
return Range(dom[0], dom[1])
516+
return ir.Range(dom[0], dom[1])
516517
if hasattr(dom, "dtype"):
517-
return Range(IntImm(dom.dtype, 0), dom)
518-
return Range(0, dom)
518+
return ir.Range(IntImm(dom.dtype, 0), dom)
519+
return ir.Range(0, dom)
519520

520521

521522
class axis: # pylint: disable=invalid-name
522523
"""The axis class"""
523524

524525
@staticmethod
525526
def spatial(
526-
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
527+
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
527528
binding: PrimExpr,
528529
dtype: str = "int32",
529530
) -> Var:
@@ -551,7 +552,7 @@ def spatial(
551552

552553
@staticmethod
553554
def reduce(
554-
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
555+
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
555556
binding: PrimExpr,
556557
dtype: str = "int32",
557558
) -> Var:
@@ -579,7 +580,7 @@ def reduce(
579580

580581
@staticmethod
581582
def scan(
582-
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
583+
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
583584
binding: PrimExpr,
584585
dtype: str = "int32",
585586
) -> Var:
@@ -607,7 +608,7 @@ def scan(
607608

608609
@staticmethod
609610
def opaque(
610-
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
611+
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
611612
binding: PrimExpr,
612613
dtype: str = "int32",
613614
) -> Var:
@@ -1288,7 +1289,7 @@ def buffer_store(
12881289

12891290
def prefetch(
12901291
buffer: Buffer, # pylint: disable=redefined-outer-name
1291-
bounds: List[Range],
1292+
bounds: List[ir.Range],
12921293
) -> None:
12931294
"""The prefetch hint for a buffer.
12941295
@@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil
15791580
return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member
15801581

15811582

1582-
def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
1583+
def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar:
15831584
"""The iteration variable.
15841585
15851586
Parameters
@@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target:
16661667
return Target(target_config)
16671668

16681669

1670+
def Range(  # pylint: disable=invalid-name
    begin: PrimExpr, end: Union[PrimExpr, None] = None
) -> ir.Range:
    """Create a Range object.

    Thin TVMScript-facing wrapper that forwards its arguments to ``ir.Range``.

    Parameters
    ----------
    begin : PrimExpr
        The begin value of the range.

    end : Optional[PrimExpr]
        The end value of the range. May be omitted, in which case ``None`` is
        forwarded to ``ir.Range`` unchanged.
        NOTE(review): ``ir.Range`` is expected to interpret a missing ``end``
        as the range ``[0, begin)`` -- confirm against ``tvm.ir.Range``.

    Returns
    -------
    res : ir.Range
        The constructed range.
    """
    # `end` defaults to None so the signature matches the documented
    # Optional[PrimExpr]; two-argument callers behave exactly as before.
    return ir.Range(begin, end)
1683+
1684+
16691685
class meta_var: # pylint: disable=invalid-name
16701686
"""A meta variable used in TVMScript metaprogramming. It means that the value of the variable
16711687
does not appear in the final TIR, but only stays in the parser.
@@ -1782,6 +1798,11 @@ def wrapped(*args, **kwargs):
17821798
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
17831799
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
17841800
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
1801+
tvm_storage_sync = _tir_op.tvm_storage_sync
1802+
tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
1803+
tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
1804+
tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
1805+
tvm_warp_activemask = _tir_op.tvm_warp_activemask
17851806
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
17861807
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
17871808
assume = _op_wrapper(_tir_op.assume)
@@ -2042,6 +2063,11 @@ def wrapped(*args, **kwargs):
20422063
"tvm_bmma_sync",
20432064
"tvm_fill_fragment",
20442065
"tvm_store_matrix_sync",
2066+
"tvm_storage_sync",
2067+
"tvm_warp_shuffle",
2068+
"tvm_warp_shuffle_up",
2069+
"tvm_warp_shuffle_down",
2070+
"tvm_warp_activemask",
20452071
"ptx_mma",
20462072
"ptx_mma_sp",
20472073
"ptx_ldmatrix",
@@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs):
21092135
"Let",
21102136
"IterVar",
21112137
"CommReducer",
2138+
"Range",
21122139
]

python/tvm/tir/op.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,8 @@ def lookup_param(param_name, span=None):
569569

570570

571571
def tvm_thread_allreduce(*freduce_args):
572-
"""
572+
"""Perform allreduce inside threadblock.
573+
573574
Parameters
574575
----------
575576
freduce_args : Expr
@@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args):
583584
return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
584585

585586

587+
def tvm_storage_sync(storage_scope):
    """Synchronize within the given storage scope.

    Parameters
    ----------
    storage_scope : str
        The storage scope in which to perform the synchronization.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    intrin_name = "tir.tvm_storage_sync"
    return call_intrin("handle", intrin_name, storage_scope)
601+
602+
603+
def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
    """Exchange a value between the threads of a warp.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicating the active threads inside the warp.
    value : PrimExpr
        The value to exchange.
    warp_id : PrimExpr
        The index of the source lane to fetch the value from.
    width : PrimExpr
        The width of the sub-sections in which the warp shuffle is performed.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    shuffle_args = (mask, value, warp_id, width, warp_size)
    # The call is typed with the dtype of the value being exchanged.
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle", *shuffle_args)
625+
626+
627+
def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
    """Copy a value from the lane whose index is lower (by ``offset``) than the caller's.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicating the active threads inside the warp.
    value : PrimExpr
        The value to exchange.
    offset : PrimExpr
        The difference between the destination and source lane indices:
        `offset = dst_lane_idx - src_lane_idx`
    width : PrimExpr
        The width of the sub-sections in which the warp shuffle is performed.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    shuffle_args = (mask, value, offset, width, warp_size)
    # The call is typed with the dtype of the value being exchanged.
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", *shuffle_args)
652+
653+
654+
def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
    """Copy a value from the lane whose index is higher (by ``offset``) than the caller's.

    Parameters
    ----------
    mask : PrimExpr
        The warp mask indicating the active threads inside the warp.
    value : PrimExpr
        The value to exchange.
    offset : PrimExpr
        The difference between the source and destination lane indices:
        `offset = src_lane_idx - dst_lane_idx`
    width : PrimExpr
        The width of the sub-sections in which the warp shuffle is performed.
    warp_size : PrimExpr
        The warp size.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    shuffle_args = (mask, value, offset, width, warp_size)
    # The call is typed with the dtype of the value being exchanged.
    return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", *shuffle_args)
679+
680+
681+
def tvm_warp_activemask():
    """Return a 32-bit mask indicating the currently active threads in the calling warp.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    mask_dtype = "uint32"
    return call_intrin(mask_dtype, "tir.tvm_warp_activemask")
690+
691+
586692
def type_annotation(dtype):
587693
"""Create a type annotation expression
588694

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,6 +3623,60 @@ def main(A: T.handle, B: T.handle):
36233623
return main
36243624

36253625

3626+
def tvm_shfl_builtins():
# Round-trip fixture: returns a PrimFunc whose body uses the warp-level
# builtins T.tvm_storage_sync, T.tvm_warp_shuffle, T.tvm_warp_shuffle_up,
# T.tvm_warp_shuffle_down and T.tvm_warp_activemask, so that the TVMScript
# parser is exercised on each of them.
3627+
@T.prim_func
3628+
def func(
3629+
A: T.handle("float32"),
3630+
B: T.handle("float32"),
3631+
C: T.handle("float32"),
3632+
):
3633+
blockIdx_x = T.launch_thread("blockIdx.x", 1)
3634+
threadIdx_x = T.launch_thread("threadIdx.x", 32)
3635+
A_warp = T.allocate([1], "float32", "local")
3636+
B_warp = T.allocate([1], "float32", "local")
3637+
red_buf0 = T.allocate([1], "float32", "local")
3638+
A_warp_1 = T.Buffer((32,), data=A_warp, scope="local")
3639+
A_1 = T.Buffer((32,), data=A)
3640+
A_warp_1[0] = A_1[threadIdx_x]
3641+
B_warp_1 = T.Buffer((32,), data=B_warp, scope="local")
3642+
# Covers tvm_storage_sync with the "warp" scope.
T.tvm_storage_sync("warp")
3643+
# Covers tvm_warp_shuffle with tvm_warp_activemask() passed inline as the mask.
B_warp_1[0] = T.tvm_warp_shuffle(
3644+
T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32
3645+
) + T.float32(1)
3646+
red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
3647+
with T.attr(
3648+
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
3649+
"reduce_scope",
3650+
T.reinterpret("handle", T.uint64(0)),
3651+
):
3652+
mask = T.allocate([1], "uint32", "local")
3653+
t0 = T.allocate([1], "float32", "local")
3654+
red_buf0_1[0] = A_warp_1[0]
3655+
mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
3656+
# Covers tvm_warp_activemask stored into a uint32 buffer.
mask_1[0] = T.tvm_warp_activemask()
3657+
t0_1 = T.Buffer((1,), data=t0, scope="local")
3658+
# Covers tvm_warp_shuffle_down: five shuffle/accumulate steps with
# offsets 16, 8, 4, 2 and 1, each adding the shuffled value into red_buf0_1[0].
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
3659+
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
3660+
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
3661+
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
3662+
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
3663+
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
3664+
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
3665+
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
3666+
t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
3667+
red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
3668+
# Covers tvm_warp_shuffle with a constant third argument (warp_id) of 0.
red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32)
3669+
# NOTE(Zihao): test tvm_warp_shuffle_up
3670+
red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32)
3671+
if threadIdx_x == 0:
3672+
C_1 = T.Buffer((1,), data=C)
3673+
C_1[0] = red_buf0_1[0]
3674+
B_1 = T.Buffer((32,), data=B)
3675+
B_1[threadIdx_x] = B_warp_1[0]
3676+
3677+
return func
3678+
3679+
36263680
ir_generator = tvm.testing.parameter(
36273681
launch_env_thread,
36283682
opt_gemm_normalize,
@@ -3686,6 +3740,7 @@ def main(A: T.handle, B: T.handle):
36863740
let_stmt_value,
36873741
string_stride,
36883742
merge_shape_var_def,
3743+
tvm_shfl_builtins,
36893744
)
36903745

36913746

0 commit comments

Comments
 (0)