From e70d8eb491fdfbe23988ce66deecae036f4e4370 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 11 Mar 2023 10:50:26 -0800 Subject: [PATCH 01/13] init --- python/tvm/script/ir_builder/tir/ir.py | 27 +++++++++++++++++++ .../unittest/test_tvmscript_roundtrip.py | 18 +++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d65f9adea86f..cb3cdbbf5054 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1665,6 +1665,32 @@ def target(target_config: Union[Dict, str]) -> Target: ) return Target(target_config) +def Range(begin: PrimExpr, end: Optional[PrimExpr]) -> Range: + """Create a Range node. + + Parameters + ---------- + begin : PrimExpr + The begin value of the range. + + end : Optional[PrimExpr] + The end value of the range. + + Returns + ------- + res : Range + The Range node. + """ + if not isinstance(begin, PrimExpr): + raise ValueError( + f"T.Range expected a PrimExpr as begin value, but got {type(begin)} instead." + ) + if not isinstance(end, PrimExpr) and end is not None: + raise ValueError( + f"T.Range expected a Optional[PrimExpr] as end value, but got {type(end)} instead." + ) + return Range(begin, end) + class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable @@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs): "Let", "IterVar", "CommReducer", + "Range" ] diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c956f3bb02b9..350d06395b6c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3623,6 +3623,24 @@ def main(A: T.handle, B: T.handle): return main +def iter_var_range(): + T.prim_func + + def func(): + blockIdx_x = T.int32() + threadIdx_x = T.int32() + T.func_attr( + { + "tir.device_thread_axis": [ + T.iter_var(blockIdx_x, T.Range(0, 1), "ThreadIndex", "blockIdx.x"), + T.iter_var(threadIdx_x, T.Range(0, 32), "ThreadIndex", "threadIdx.x"), + ] + } + ) + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, From b3b98fa461f75c8f7de21c2e1986671ee57daab4 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 05:22:39 -0700 Subject: [PATCH 02/13] upd --- python/tvm/script/ir_builder/tir/ir.py | 8 ++ python/tvm/tir/op.py | 104 ++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index cb3cdbbf5054..ce9c9dce17a2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1808,6 +1808,11 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +tvm_storage_sync = _op_wrapper(_tir_op.tvm_storage_sync) +tvm_warp_shuffle = _op_wrapper(_tir_op.tvm_warp_shuffle) +tvm_warp_shuffle_up = _op_wrapper(_tir_op.tvm_warp_shuffle_up) +tvm_warp_shuffle_down = _op_wrapper(_tir_op.tvm_warp_shuffle_down) +tvm_warp_activemask = _op_wrapper(_tir_op.tvm_warp_activemask) ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) assume = _op_wrapper(_tir_op.assume) @@ -2068,6 +2073,9 @@ def wrapped(*args, **kwargs): "tvm_bmma_sync", "tvm_fill_fragment", "tvm_store_matrix_sync", + "tvm_storage_sync", + "tvm_warp_shuffle", + "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", "ptx_ldmatrix", diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0a9c4fdfaa52..498926045667 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -569,7 +569,8 @@ def lookup_param(param_name, span=None): def tvm_thread_allreduce(*freduce_args): - """ + """Perform allreduce inside threadblock. + Parameters ---------- freduce_args : Expr @@ -583,6 +584,107 @@ def tvm_thread_allreduce(*freduce_args): return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args) +def tvm_storage_sync(storage_scope): + """Perform synchronization in specified scope. + + Parameters + ---------- + storage_scope : str + The storage scope to perform synchronization. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_storage_sync", storage_scope) + + +def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): + """Exchange value between threads inside a warp. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + warp_id : PrimExpr + The source lane index to fetch value. + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + + +def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): + """Copy value from a lane with lower (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = dst_lane_idx - src_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + + +def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): + """Copy value from a lane with higher (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = src_lane_idx - dst_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + + +def tvm_warp_activemask(): + """Return a 32-bit mask indicates currently active threads in a calling warp. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_activemask") + + def type_annotation(dtype): """Create a type annotation expression From 409f7c5c3403fefae78dc9e49c3245f1cfa82d71 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 05:59:42 -0700 Subject: [PATCH 03/13] upd --- python/tvm/script/ir_builder/tir/ir.py | 55 +++++++---------- python/tvm/tir/op.py | 8 +-- .../unittest/test_tvmscript_roundtrip.py | 61 ++++++++++++++----- 3 files changed, 74 insertions(+), 50 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ce9c9dce17a2..3abaadd8cf99 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -29,7 +29,8 @@ import numpy as np # type: ignore from tvm import tir -from tvm.ir import Range, Type +from tvm import ir +from tvm.ir import Type from tvm.ir.base import deprecated from tvm.runtime import String, convert, ndarray from tvm.target import Target @@ -496,7 +497,7 @@ def alloc_buffer( ) -def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: +def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: """The range constructor. Parameters @@ -509,13 +510,13 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: res : Range The Range. """ - if isinstance(dom, Range): + if isinstance(dom, ir.Range): return dom if isinstance(dom, (list, tuple)): - return Range(dom[0], dom[1]) + return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): - return Range(IntImm(dom.dtype, 0), dom) - return Range(0, dom) + return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(0, dom) class axis: # pylint: disable=invalid-name @@ -523,7 +524,7 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -551,7 +552,7 @@ def spatial( @staticmethod def reduce( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -579,7 +580,7 @@ def reduce( @staticmethod def scan( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -607,7 +608,7 @@ def scan( @staticmethod def opaque( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -1288,7 +1289,7 @@ def buffer_store( def prefetch( buffer: Buffer, # pylint: disable=redefined-outer-name - bounds: List[Range], + bounds: List[ir.Range], ) -> None: """The prefetch hint for a buffer. @@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member -def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar: +def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar: """The iteration variable. Parameters @@ -1665,32 +1666,20 @@ def target(target_config: Union[Dict, str]) -> Target: ) return Target(target_config) -def Range(begin: PrimExpr, end: Optional[PrimExpr]) -> Range: - """Create a Range node. - + +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: + """ + Create a Range object. + Parameters ---------- - begin : PrimExpr + begin : PrimExpr The begin value of the range. end : Optional[PrimExpr] The end value of the range. - - Returns - ------- - res : Range - The Range node. """ - if not isinstance(begin, PrimExpr): - raise ValueError( - f"T.Range expected a PrimExpr as begin value, but got {type(begin)} instead." - ) - if not isinstance(end, PrimExpr) and end is not None: - raise ValueError( - f"T.Range expected a Optional[PrimExpr] as end value, but got {type(end)} instead." - ) - return Range(begin, end) - + return ir.Range(begin, end) class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable @@ -2075,6 +2064,8 @@ def wrapped(*args, **kwargs): "tvm_store_matrix_sync", "tvm_storage_sync", "tvm_warp_shuffle", + "tvm_warp_shuffle_up", + "tvm_warp_shuffle_down", "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", @@ -2143,5 +2134,5 @@ def wrapped(*args, **kwargs): "Let", "IterVar", "CommReducer", - "Range" + "Range", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 498926045667..dd14d728128d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -621,7 +621,7 @@ def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): @@ -646,7 +646,7 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): @@ -671,7 +671,7 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) def tvm_warp_activemask(): @@ -682,7 +682,7 @@ def tvm_warp_activemask(): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_activemask") + return call_intrin("int32", "tir.tvm_warp_activemask") def type_annotation(dtype): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 350d06395b6c..9f3c18b7ca6b 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3623,20 +3623,52 @@ def main(A: T.handle, B: T.handle): return main -def iter_var_range(): - T.prim_func - - def func(): - blockIdx_x = T.int32() - threadIdx_x = T.int32() - T.func_attr( - { - "tir.device_thread_axis": [ - T.iter_var(blockIdx_x, T.Range(0, 1), "ThreadIndex", "blockIdx.x"), - T.iter_var(threadIdx_x, T.Range(0, 32), "ThreadIndex", "threadIdx.x"), - ] - } - ) +def tvm_shfl_builtins(): + @T.prim_func + def func( + A: T.handle("float32", "global"), + C: T.handle("float32", "global"), + B: T.handle("float32", "global"), + ): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + A_warp = T.allocate([32], "float32", "warp") + B_warp = T.allocate([32], "float32", "warp") + red_buf0 = T.allocate([1], "float32", "local") + A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") + A_1 = T.Buffer((32,), data=A) + A_warp_1[threadIdx_x] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") + T.tvm_storage_sync("warp") + B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) + red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1[0] = A_warp_1[threadIdx_x] + mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_1[0] = T.tvm_warp_activemask() + t0_1 = T.Buffer((1,), data=t0, scope="local") + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32) + if threadIdx_x == 0: + C_1 = T.Buffer((1,), data=C) + C_1[0] = red_buf0_1[0] + B_1 = T.Buffer((32,), data=B) + B_1[threadIdx_x] = B_warp_1[threadIdx_x] return func @@ -3704,6 +3736,7 @@ def func(): let_stmt_value, string_stride, merge_shape_var_def, + tvm_shfl_builtins, ) From 9b9225269c3ac13e5dd572c85194fd1fb5c71873 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 06:45:52 -0700 Subject: [PATCH 04/13] add tests --- python/tvm/script/ir_builder/tir/ir.py | 3 ++- .../unittest/test_tvmscript_roundtrip.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 3abaadd8cf99..fce1fa5c6010 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1675,12 +1675,13 @@ def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: ---------- begin : PrimExpr The begin value of the range. - + end : Optional[PrimExpr] The end value of the range. """ return ir.Range(begin, end) + class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable does not appear in the final TIR, but only stays in the parser. diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 9f3c18b7ca6b..dc2a48b43e7c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3632,15 +3632,17 @@ def func( ): blockIdx_x = T.launch_thread("blockIdx.x", 1) threadIdx_x = T.launch_thread("threadIdx.x", 32) - A_warp = T.allocate([32], "float32", "warp") - B_warp = T.allocate([32], "float32", "warp") + A_warp = T.allocate([1], "float32", "local") + B_warp = T.allocate([1], "float32", "local") red_buf0 = T.allocate([1], "float32", "local") - A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") + A_warp_1 = T.Buffer((32,), data=A_warp, scope="local") A_1 = T.Buffer((32,), data=A) - A_warp_1[threadIdx_x] = A_1[threadIdx_x] - B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") + A_warp_1[0] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") T.tvm_storage_sync("warp") - B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) + B_warp_1[0] = T.tvm_warp_shuffle( + T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 + ) + T.float32(1) red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -3649,7 +3651,7 @@ def func( ): mask = T.allocate([1], "uint32", "local") t0 = T.allocate([1], "float32", "local") - red_buf0_1[0] = A_warp_1[threadIdx_x] + red_buf0_1[0] = A_warp_1[0] mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_1[0] = T.tvm_warp_activemask() t0_1 = T.Buffer((1,), data=t0, scope="local") @@ -3668,7 +3670,7 @@ def func( C_1 = T.Buffer((1,), data=C) C_1[0] = red_buf0_1[0] B_1 = T.Buffer((32,), data=B) - B_1[threadIdx_x] = B_warp_1[threadIdx_x] + B_1[threadIdx_x] = B_warp_1[0] return func From ec5346015e4abe1960ce7035f8ea4e2979cd3027 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 06:51:57 -0700 Subject: [PATCH 05/13] upd --- tests/python/unittest/test_tvmscript_roundtrip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index dc2a48b43e7c..337e23b0209b 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3666,6 +3666,8 @@ def func( t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) red_buf0_1[0] = red_buf0_1[0] + t0_1[0] red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32) + # NOTE(Zihao): test tvm_warp_shuffle_up + red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32) if threadIdx_x == 0: C_1 = T.Buffer((1,), data=C) C_1[0] = red_buf0_1[0] From 4cc6681568d9f15170e0f37b5823813972b0b4d6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:25:55 -0700 Subject: [PATCH 06/13] remove _op_wrapper --- python/tvm/script/ir_builder/tir/ir.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index fce1fa5c6010..57090c6f3f39 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1798,11 +1798,11 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) -tvm_storage_sync = _op_wrapper(_tir_op.tvm_storage_sync) -tvm_warp_shuffle = _op_wrapper(_tir_op.tvm_warp_shuffle) -tvm_warp_shuffle_up = _op_wrapper(_tir_op.tvm_warp_shuffle_up) -tvm_warp_shuffle_down = _op_wrapper(_tir_op.tvm_warp_shuffle_down) -tvm_warp_activemask = _op_wrapper(_tir_op.tvm_warp_activemask) +tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_warp_shuffle = _tir_op.tvm_warp_shuffle +tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up +tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down +tvm_warp_activemask = _tir_op.tvm_warp_activemask ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) assume = _op_wrapper(_tir_op.assume) From 4a5396ab7ed9b0667e7ba577ba5ebe9aee2264e6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:49:14 -0700 Subject: [PATCH 07/13] fix --- python/tvm/tir/op.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index dd14d728128d..0fe460c085d7 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -646,7 +646,9 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size + ) def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): @@ -671,7 +673,9 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size + ) def tvm_warp_activemask(): @@ -682,7 +686,7 @@ def tvm_warp_activemask(): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.tvm_warp_activemask") + return call_intrin("uint32", "tir.tvm_warp_activemask") def type_annotation(dtype): From fdabd56ee4f369c2136ac681cfbea2f1252d4580 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 08:08:53 -0700 Subject: [PATCH 08/13] flake --- tests/python/unittest/test_tvmscript_roundtrip.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 337e23b0209b..6f07b6a75aeb 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3626,9 +3626,9 @@ def main(A: T.handle, B: T.handle): def tvm_shfl_builtins(): @T.prim_func def func( - A: T.handle("float32", "global"), - C: T.handle("float32", "global"), - B: T.handle("float32", "global"), + A: T.handle("float32"), + B: T.handle("float32"), + C: T.handle("float32"), ): blockIdx_x = T.launch_thread("blockIdx.x", 1) threadIdx_x = T.launch_thread("threadIdx.x", 32) From 5c4228a76eaa4793e39cd4b6b862f6d4ad79363f Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 08:34:00 -0700 Subject: [PATCH 09/13] pylint --- python/tvm/script/ir_builder/tir/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 57090c6f3f39..45350c5a65c7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1667,7 +1667,7 @@ def target(target_config: Union[Dict, str]) -> Target: return Target(target_config) -def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name """ Create a Range object. From c75c2335f26f7c97aabcb642c0ea9de0139b63d0 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 10 Mar 2023 08:16:48 -0800 Subject: [PATCH 10/13] try --- src/driver/driver_api.cc | 2 ++ src/tir/transforms/lower_warp_memory.cc | 12 ++++++---- .../test_tir_transform_lower_warp_memory.py | 24 ++++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index da1bbc296a49..184b98e381ac 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -659,7 +659,9 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) device_pass_list.push_back(tir::transform::BindTarget(target)); + device_pass_list.push_back(transform::PrintIR()); device_pass_list.push_back(tir::transform::LowerWarpMemory()); + device_pass_list.push_back(transform::PrintIR()); device_pass_list.push_back(tir::transform::Simplify()); device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 9d2ff88540fc..d587d173b802 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -130,10 +130,13 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) final { + LOG(INFO) << "wut"; + LOG(INFO) << op->buffer << " " << GetRef(buffer_); if (op->buffer->data.get() != buffer_) { StmtVisitor::VisitStmt_(op); return; } + LOG(INFO) << "hey"; ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " << "Has StorageFlatten (TE-based schedule) or " @@ -160,7 +163,9 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " "thread local registers and shuffling values between these registers. Currently only " "linear equation indices are supported."; + LOG(INFO) << "index = " << index; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); + LOG(INFO) << "mcoeff = " << mcoeff; const auto* mcoeff_as_int = mcoeff.as(); ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index @@ -243,10 +248,13 @@ class WarpAccessRewriter : protected StmtExprMutator { ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); + LOG(INFO) << "warp_index = " << warp_index_; + LOG(INFO) << "body = " << op->body; warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); // Align the local memory size. The number of elements may not // be a multiple of width_ * warp_coeff_; round it up. + LOG(INFO) << width_ << " " << warp_coeff_; int factor = width_ * warp_coeff_; ICHECK_NE(factor, 0) << "Divide by zero"; warp_group_ = (alloc_size + (factor - 1)) / factor; @@ -331,10 +339,6 @@ class WarpAccessRewriter : protected StmtExprMutator { << "FlattenBuffer (TIR-based schedules) been run?"; auto [local_index, group] = SplitIndexByGroup(op->indices[0]); - // invariance: local index must do not contain warp id - ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) - << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] - << " local_index=" << local_index; auto writer = load.CopyOnWrite(); writer->indices = {local_index}; diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index d4abc26bb204..635c0312d592 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -19,6 +19,7 @@ import tvm import tvm.testing from tvm import te +from tvm.script import tir as T from tvm.contrib.nvcc import have_fp16 @@ -346,6 +347,27 @@ def test_lower_warp_memory_divide_by_factor(): with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm: tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] +@T.prim_func +def func() -> None: + A = T.alloc_buffer([32], "float32", scope="warp") + B = T.alloc_buffer([32], "float32", scope="warp") + for i in range(32): + with T.block("warp_shuffle"): + B[i] = A[(i % 4) * 8 + i // 4] + 1 + + +@tvm.testing.requires_cuda +def test_warp_shuffle(): + mod = tvm.IRModule.from_expr(func) + sch = tvm.tir.Schedule(mod["main"]) + blk = sch.get_block("warp_shuffle") + i, = sch.get_loops(blk) + io, ii = sch.split(i, [1, 32]) + sch.bind(io, "blockIdx.x") + sch.bind(ii, "threadIdx.x") + f = tvm.build(sch.mod["main"], target="cuda") + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_warp_shuffle() From 107b21e2b535b02601d097a943e035688dcd4717 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 11 Mar 2023 20:13:35 -0800 Subject: [PATCH 11/13] upd --- src/tir/transforms/lower_warp_memory.cc | 8 -- .../test_tir_transform_lower_warp_memory.py | 122 +++++++++++++++++- 2 files changed, 115 insertions(+), 15 deletions(-) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index d587d173b802..f0ccef2c29a8 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -130,13 +130,10 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) final { - LOG(INFO) << "wut"; - LOG(INFO) << op->buffer << " " << GetRef(buffer_); if (op->buffer->data.get() != buffer_) { StmtVisitor::VisitStmt_(op); return; } - LOG(INFO) << "hey"; ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " << "Has StorageFlatten (TE-based schedule) or " @@ -163,9 +160,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " "thread local registers and shuffling values between these registers. Currently only " "linear equation indices are supported."; - LOG(INFO) << "index = " << index; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); - LOG(INFO) << "mcoeff = " << mcoeff; const auto* mcoeff_as_int = mcoeff.as(); ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index @@ -248,13 +243,10 @@ class WarpAccessRewriter : protected StmtExprMutator { ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); - LOG(INFO) << "warp_index = " << warp_index_; - LOG(INFO) << "body = " << op->body; warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); // Align the local memory size. The number of elements may not // be a multiple of width_ * warp_coeff_; round it up. - LOG(INFO) << width_ << " " << warp_coeff_; int factor = width_ * warp_coeff_; ICHECK_NE(factor, 0) << "Divide by zero"; warp_group_ = (alloc_size + (factor - 1)) / factor; diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 635c0312d592..09bee477e796 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -347,27 +347,135 @@ def test_lower_warp_memory_divide_by_factor(): with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm: tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + @T.prim_func -def func() -> None: - A = T.alloc_buffer([32], "float32", scope="warp") - B = T.alloc_buffer([32], "float32", scope="warp") +def func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [32], "float32") + B = T.match_buffer(b, [32], "float32") for i in range(32): with T.block("warp_shuffle"): - B[i] = A[(i % 4) * 8 + i // 4] + 1 + vi = T.axis.spatial(32, i) + B[vi] = A[(vi % 4) * 8 + vi // 4] + T.float32(1) + + +def test_warp_shuffle_transform(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): + # blockIdx_x = T.int32() + blockIdx_x = T.env_thread("blockIdx.x") + threadIdx_x = T.env_thread("threadIdx.x") + T.func_attr( + { + "calling_conv": 2, + "global_symbol": "default_function_kernel0", + "target": T.target( + { + "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, + "keys": ["cuda", "gpu"], + "kind": "cuda", + "max_num_threads": 1024, + "tag": "", + "thread_warp_size": 32, + } + ), + "tir.device_thread_axis": [ + T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"), + T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"), + ], + "tir.is_global_func": 1, + "tir.noalias": 1, + } + ) + T.launch_thread(blockIdx_x, 1) + A_warp = T.allocate([32], "float32", "warp") + B_warp = T.allocate([32], "float32", "warp") + T.launch_thread(threadIdx_x, 32) + A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") + A_1 = T.Buffer((32,), data=A) + A_warp_1[threadIdx_x] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") + # T.tvm_storage_sync("warp") + B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) + B_1 = T.Buffer((32,), data=B) + B_1[threadIdx_x] = B_warp_1[threadIdx_x] + + @tvm.script.ir_module + class Expected: + @T.prim_func + def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): + blockIdx_x = T.env_thread("blockIdx.x") + threadIdx_x = T.env_thread("threadIdx.x") + T.func_attr( + { + "calling_conv": 2, + "global_symbol": "default_function_kernel0", + "target": T.target( + { + "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, + "keys": ["cuda", "gpu"], + "kind": "cuda", + "max_num_threads": 1024, + "tag": "", + "thread_warp_size": 32, + } + ), + "tir.device_thread_axis": [ + T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"), + T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"), + ], + "tir.is_global_func": 1, + "tir.noalias": 1, + } + ) + T.launch_thread(blockIdx_x, 1) + A_warp = T.allocate([1], "float32", "local") + B_warp = T.allocate([1], "float32", "local") + T.launch_thread(threadIdx_x, 32) + A_warp_1 = T.Buffer((32,), data=A_warp, scope="local") + A_1 = T.Buffer((32,), data=A) + A_warp_1[0] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") + # T.tvm_storage_sync("warp") + B_warp_1[0] = T.tvm_warp_shuffle( + T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 + ) + T.float32(1) + B_1 = T.Buffer((32,), data=B) + B_1[threadIdx_x] = B_warp_1[0] + + after = tvm.tir.transform.LowerWarpMemory()(Before) + + tvm.ir.assert_structural_equal(after, Expected) + +@T.prim_func +def warp_shuffle(A: T.Buffer([32], "float32"), B: T.Buffer([32], "float32")) -> None: + for i in range(32): + with T.block("warp_shuffle"): + vi = T.axis.spatial(32, i) + B[vi] = A[vi % 4 * 8 + vi // 4] + T.float32(1) @tvm.testing.requires_cuda def test_warp_shuffle(): - mod = tvm.IRModule.from_expr(func) + mod = tvm.IRModule.from_expr(warp_shuffle) sch = tvm.tir.Schedule(mod["main"]) blk = sch.get_block("warp_shuffle") i, = sch.get_loops(blk) io, ii = sch.split(i, [1, 32]) - sch.bind(io, "blockIdx.x") sch.bind(ii, "threadIdx.x") + A_warp = sch.cache_read(blk, 0, "warp") + sch.compute_at(A_warp, io) + sch.bind(sch.get_loops(A_warp)[-1], "threadIdx.x") + B_warp = sch.cache_write(blk, 0, "warp") + sch.reverse_compute_at(B_warp, io) + sch.bind(sch.get_loops(B_warp)[-1], "threadIdx.x") + sch.bind(io, "blockIdx.x") + print(sch.mod["main"].script()) f = tvm.build(sch.mod["main"], target="cuda") if __name__ == "__main__": # tvm.testing.main() - test_warp_shuffle() + test_warp_shuffle_transform() + # test_warp_shuffle() From 8ed39537557cb7a464171d1b741e21093d6d7efb Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:46:09 -0700 Subject: [PATCH 12/13] fix --- tests/python/unittest/test_tir_transform_lower_warp_memory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 09bee477e796..3315954a566d 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -396,7 +396,7 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): A_1 = T.Buffer((32,), data=A) A_warp_1[threadIdx_x] = A_1[threadIdx_x] B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") - # T.tvm_storage_sync("warp") + T.tvm_storage_sync("warp") B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) B_1 = T.Buffer((32,), data=B) B_1[threadIdx_x] = B_warp_1[threadIdx_x] @@ -437,7 +437,7 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): A_1 = T.Buffer((32,), data=A) A_warp_1[0] = A_1[threadIdx_x] B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") - # T.tvm_storage_sync("warp") + T.tvm_storage_sync("warp") B_warp_1[0] = T.tvm_warp_shuffle( T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 ) + T.float32(1) From 7d8eff7fc3e6f8e676af6716810499a52ee09e0d Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:54:17 -0700 Subject: [PATCH 13/13] upd --- src/driver/driver_api.cc | 2 -- .../test_tir_transform_lower_warp_memory.py | 35 ++----------------- 2 files changed, 3 insertions(+), 34 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 184b98e381ac..da1bbc296a49 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -659,9 +659,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) device_pass_list.push_back(tir::transform::BindTarget(target)); - device_pass_list.push_back(transform::PrintIR()); device_pass_list.push_back(tir::transform::LowerWarpMemory()); - device_pass_list.push_back(transform::PrintIR()); device_pass_list.push_back(tir::transform::Simplify()); device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 3315954a566d..68c1164e262f 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -363,13 +363,12 @@ def test_warp_shuffle_transform(): class Before: @T.prim_func def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): - # blockIdx_x = T.int32() blockIdx_x = T.env_thread("blockIdx.x") threadIdx_x = T.env_thread("threadIdx.x") T.func_attr( { "calling_conv": 2, - "global_symbol": "default_function_kernel0", + "global_symbol": "main", "target": T.target( { "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, @@ -410,7 +409,7 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): T.func_attr( { "calling_conv": 2, - "global_symbol": "default_function_kernel0", + "global_symbol": "main", "target": T.target( { "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, @@ -448,34 +447,6 @@ def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")): tvm.ir.assert_structural_equal(after, Expected) -@T.prim_func -def warp_shuffle(A: T.Buffer([32], "float32"), B: T.Buffer([32], "float32")) -> None: - for i in range(32): - with T.block("warp_shuffle"): - vi = T.axis.spatial(32, i) - B[vi] = A[vi % 4 * 8 + vi // 4] + T.float32(1) - - -@tvm.testing.requires_cuda -def test_warp_shuffle(): - mod = tvm.IRModule.from_expr(warp_shuffle) - sch = tvm.tir.Schedule(mod["main"]) - blk = sch.get_block("warp_shuffle") - i, = sch.get_loops(blk) - io, ii = sch.split(i, [1, 32]) - sch.bind(ii, "threadIdx.x") - A_warp = sch.cache_read(blk, 0, "warp") - sch.compute_at(A_warp, io) - sch.bind(sch.get_loops(A_warp)[-1], "threadIdx.x") - B_warp = sch.cache_write(blk, 0, "warp") - sch.reverse_compute_at(B_warp, io) - sch.bind(sch.get_loops(B_warp)[-1], "threadIdx.x") - sch.bind(io, "blockIdx.x") - print(sch.mod["main"].script()) - f = tvm.build(sch.mod["main"], target="cuda") - if __name__ == "__main__": - # tvm.testing.main() - test_warp_shuffle_transform() - # test_warp_shuffle() + tvm.testing.main()