51 changes: 39 additions & 12 deletions python/tvm/script/ir_builder/tir/ir.py
@@ -29,7 +29,8 @@
import numpy as np # type: ignore

from tvm import tir
from tvm.ir import Range, Type
from tvm import ir
from tvm.ir import Type
from tvm.ir.base import deprecated
from tvm.runtime import String, convert, ndarray
from tvm.target import Target
@@ -496,7 +497,7 @@ def alloc_buffer(
)


def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
"""The range constructor.

Parameters
@@ -509,21 +510,21 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
res : Range
The Range.
"""
if isinstance(dom, Range):
if isinstance(dom, ir.Range):
return dom
if isinstance(dom, (list, tuple)):
return Range(dom[0], dom[1])
return ir.Range(dom[0], dom[1])
if hasattr(dom, "dtype"):
return Range(IntImm(dom.dtype, 0), dom)
return Range(0, dom)
return ir.Range(IntImm(dom.dtype, 0), dom)
return ir.Range(0, dom)
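
For reference, a quick sketch of the coercions this helper performs (illustrative only; _as_range is module-private, so these calls assume you are inside this module):

from tvm import ir, tir

_as_range(ir.Range(0, 8))   # a Range passes through unchanged
_as_range((2, 10))          # a (begin, end) pair becomes ir.Range(2, 10)
n = tir.Var("n", "int64")
_as_range(n)                # a typed extent becomes ir.Range(IntImm("int64", 0), n)
_as_range(16)               # a plain extent becomes ir.Range(0, 16)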


class axis: # pylint: disable=invalid-name
"""The axis class"""

@staticmethod
def spatial(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
@@ -551,7 +552,7 @@ def spatial(

@staticmethod
def reduce(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
@@ -579,7 +580,7 @@ def reduce(

@staticmethod
def scan(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
@@ -607,7 +608,7 @@ def scan(

@staticmethod
def opaque(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
) -> Var:
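
For context, a minimal sketch of how these dom forms appear in TVMScript block bindings; the buffer shapes and names are illustrative, and a bare extent, a (begin, end) pair, or an explicit Range are all accepted:

from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
    for i, j in T.grid(32, 32):
        with T.block("add"):
            vi = T.axis.spatial(32, i)       # bare extent: dom = [0, 32)
            vj = T.axis.spatial((0, 32), j)  # explicit (begin, end) pair
            B[vi, vj] = A[vi, vj] + T.float32(1)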
@@ -1288,7 +1289,7 @@ def buffer_store(

def prefetch(
buffer: Buffer, # pylint: disable=redefined-outer-name
bounds: List[Range],
bounds: List[ir.Range],
) -> None:
"""The prefetch hint for a buffer.

@@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil
return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member


def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar:
"""The iteration variable.

Parameters
@@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target:
return Target(target_config)


def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name
"""
Create a Range object.

Parameters
----------
begin : PrimExpr
The begin value of the range.

end : PrimExpr
The end value of the range.

Returns
-------
res : Range
The Range.
"""
return ir.Range(begin, end)
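
A hypothetical usage sketch for the new T.Range builder, e.g. to spell out an iter_var domain explicitly (the name "tx" and the bounds are assumptions for illustration):

from tvm.script import tir as T

# dom spans [0, 32); equivalent to passing the list [0, 32]
iv = T.iter_var("tx", T.Range(0, 32), "ThreadIndex", "threadIdx.x")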


class meta_var: # pylint: disable=invalid-name
"""A meta variable used in TVMScript metaprogramming. It means that the value of the variable
does not appear in the final TIR, but only stays in the parser.
@@ -1782,6 +1798,11 @@ def wrapped(*args, **kwargs):
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
tvm_storage_sync = _tir_op.tvm_storage_sync
tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
tvm_warp_activemask = _tir_op.tvm_warp_activemask
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
assume = _op_wrapper(_tir_op.assume)
@@ -2042,6 +2063,11 @@ def wrapped(*args, **kwargs):
"tvm_bmma_sync",
"tvm_fill_fragment",
"tvm_store_matrix_sync",
"tvm_storage_sync",
"tvm_warp_shuffle",
"tvm_warp_shuffle_up",
"tvm_warp_shuffle_down",
"tvm_warp_activemask",
"ptx_mma",
"ptx_mma_sp",
"ptx_ldmatrix",
@@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs):
"Let",
"IterVar",
"CommReducer",
"Range",
]
108 changes: 107 additions & 1 deletion python/tvm/tir/op.py
@@ -569,7 +569,8 @@ def lookup_param(param_name, span=None):


def tvm_thread_allreduce(*freduce_args):
"""
"""Perform allreduce inside threadblock.

Parameters
----------
freduce_args : Expr
@@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args):
return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)


def tvm_storage_sync(storage_scope):
"""Perform synchronization in specified scope.

Parameters
----------
storage_scope : str
The storage scope in which to perform the synchronization (e.g. "shared" or "warp").

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)


def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
"""Exchange value between threads inside a warp.

Parameters
----------
mask : PrimExpr
The warp mask indicating the active threads inside the warp.
value : PrimExpr
The value to exchange.
warp_id : PrimExpr
The source lane index from which to fetch the value.
width : PrimExpr
The width of the sub-sections within which the warp shuffle is performed.
warp_size : PrimExpr
The warp size.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size)
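
Semantically this follows CUDA's __shfl_sync: lanes are grouped into width-sized sub-sections, and the source index is taken modulo width within the caller's sub-section. A small pure-Python model of what each lane reads (an illustration of the semantics, not the lowered code):

def shuffle_model(values, lane, warp_id, width):
    # values[i] is the register value held by lane i before the shuffle
    base = (lane // width) * width         # start of this lane's sub-section
    return values[base + warp_id % width]  # value fetched by `lane`

assert shuffle_model(list(range(32)), lane=5, warp_id=3, width=8) == 3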


def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
"""Copy value from a lane with lower (by offset) index relative to caller.

Parameters
----------
mask : PrimExpr
The warp mask indicating the active threads inside the warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between the destination and source lane indices:
`offset = dst_lane_idx - src_lane_idx`
width : PrimExpr
The width of the sub-sections within which the warp shuffle is performed.
warp_size : PrimExpr
The warp size.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size
)
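
The same model extended to shuffle-up: the source is the caller's lane minus offset, and a source below the caller's sub-section leaves the caller's own value unchanged (again an illustration of the intended __shfl_up_sync-like semantics, not the implementation):

def shuffle_up_model(values, lane, offset, width):
    base = (lane // width) * width
    src = lane - offset
    # below the sub-section boundary, the lane keeps its own value
    return values[src] if src >= base else values[lane]

assert shuffle_up_model(list(range(32)), lane=0, offset=1, width=32) == 0
assert shuffle_up_model(list(range(32)), lane=9, offset=1, width=32) == 8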


def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
"""Copy value from a lane with higher (by offset) index relative to caller.

Parameters
----------
mask : PrimExpr
The warp mask indicating the active threads inside the warp.
value : PrimExpr
The value to exchange.
offset : PrimExpr
The difference between the source and destination lane indices:
`offset = src_lane_idx - dst_lane_idx`
width : PrimExpr
The width of the sub-sections within which the warp shuffle is performed.
warp_size : PrimExpr
The warp size.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size
)
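
Shuffle-down is the building block of the classic warp tree reduction; a hypothetical model of that pattern, with out-of-range lanes reading their own value back as in __shfl_down_sync:

def warp_reduce_sum_model(vals, width=32):
    # models `val += shuffle_down(val, offset)` for offset = 16, 8, 4, 2, 1;
    # upper lanes compute garbage, but lane 0 ends with the full sum
    vals = list(vals)
    offset = width // 2
    while offset > 0:
        vals = [v + vals[i + offset] if i + offset < width else v + v
                for i, v in enumerate(vals)]
        offset //= 2
    return vals[0]

assert warp_reduce_sum_model(range(32)) == sum(range(32))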


def tvm_warp_activemask():
"""Return a 32-bit mask indicates currently active threads in a calling warp.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("uint32", "tir.tvm_warp_activemask")


def type_annotation(dtype):
"""Create a type annotation expression

4 changes: 0 additions & 4 deletions src/tir/transforms/lower_warp_memory.cc
@@ -331,10 +331,6 @@ class WarpAccessRewriter : protected StmtExprMutator {
<< "FlattenBuffer (TIR-based schedules) been run?";

auto [local_index, group] = SplitIndexByGroup(op->indices[0]);
// invariance: local index must do not contain warp id
ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
<< "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0]
<< " local_index=" << local_index;

auto writer = load.CopyOnWrite();
writer->indices = {local_index};
101 changes: 101 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -19,6 +19,7 @@
import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T
from tvm.contrib.nvcc import have_fp16


@@ -347,5 +348,105 @@ def test_lower_warp_memory_divide_by_factor():
tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]


@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [32], "float32")
B = T.match_buffer(b, [32], "float32")
for i in range(32):
with T.block("warp_shuffle"):
vi = T.axis.spatial(32, i)
B[vi] = A[(vi % 4) * 8 + vi // 4] + T.float32(1)

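The index (vi % 4) * 8 + vi // 4 is a bijective cross-lane permutation of the 32 lanes, which is exactly what forces LowerWarpMemory to rewrite the load into a warp shuffle; a quick sanity check (illustration only):

perm = [(vi % 4) * 8 + vi // 4 for vi in range(32)]
assert sorted(perm) == list(range(32))  # each lane reads exactly one distinct lane
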

def test_warp_shuffle_transform():
Contributor

The test looks reasonable as-is, though there's also a tvm.testing.CompareBeforeAfter that you could use to further reduce the boilerplate.

class TestWarpShuffleTransform(tvm.testing.CompareBeforeAfter):
    transform = tvm.tir.transform.LowerWarpMemory()

    def before(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
        ...

    def expected(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
        ...

@tvm.script.ir_module
class Before:
@T.prim_func
def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
blockIdx_x = T.env_thread("blockIdx.x")
threadIdx_x = T.env_thread("threadIdx.x")
T.func_attr(
Contributor

It looks like the test case only requires the "target" attribute, and only requires "kind" and "thread_warp_size" within that. Can we remove the extra attributes from the unit test?

{
"calling_conv": 2,
"global_symbol": "main",
"target": T.target(
{
"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
"keys": ["cuda", "gpu"],
"kind": "cuda",
"max_num_threads": 1024,
"tag": "",
"thread_warp_size": 32,
}
),
"tir.device_thread_axis": [
T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"),
T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"),
],
"tir.is_global_func": 1,
"tir.noalias": 1,
}
)
T.launch_thread(blockIdx_x, 1)
A_warp = T.allocate([32], "float32", "warp")
B_warp = T.allocate([32], "float32", "warp")
T.launch_thread(threadIdx_x, 32)
A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp")
A_1 = T.Buffer((32,), data=A)
Contributor

Instead of having a separate A: T.handle and A_1: T.Buffer, the buffer could be declared as a parameter A_1: T.Buffer(32). It does result in slightly different TIR, as it follows the style from before MakePackedAPI is applied, but for a unit test it would help to emphasize the change being tested.

A_warp_1[threadIdx_x] = A_1[threadIdx_x]
B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp")
T.tvm_storage_sync("warp")
B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1)
Contributor

Could we add a comment here, indicating that this line is the one that should be updated correctly?

B_1 = T.Buffer((32,), data=B)
B_1[threadIdx_x] = B_warp_1[threadIdx_x]

@tvm.script.ir_module
class Expected:
@T.prim_func
def main(A: T.handle("float32", "global"), B: T.handle("float32", "global")):
blockIdx_x = T.env_thread("blockIdx.x")
threadIdx_x = T.env_thread("threadIdx.x")
T.func_attr(
{
"calling_conv": 2,
"global_symbol": "main",
"target": T.target(
{
"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""},
"keys": ["cuda", "gpu"],
"kind": "cuda",
"max_num_threads": 1024,
"tag": "",
"thread_warp_size": 32,
}
),
"tir.device_thread_axis": [
T.iter_var(blockIdx_x, [0, 1], "ThreadIndex", "blockIdx.x"),
T.iter_var(threadIdx_x, [0, 32], "ThreadIndex", "threadIdx.x"),
],
"tir.is_global_func": 1,
"tir.noalias": 1,
}
)
T.launch_thread(blockIdx_x, 1)
A_warp = T.allocate([1], "float32", "local")
B_warp = T.allocate([1], "float32", "local")
T.launch_thread(threadIdx_x, 32)
A_warp_1 = T.Buffer((32,), data=A_warp, scope="local")
A_1 = T.Buffer((32,), data=A)
A_warp_1[0] = A_1[threadIdx_x]
B_warp_1 = T.Buffer((32,), data=B_warp, scope="local")
T.tvm_storage_sync("warp")
B_warp_1[0] = T.tvm_warp_shuffle(
T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32
) + T.float32(1)
B_1 = T.Buffer((32,), data=B)
B_1[threadIdx_x] = B_warp_1[0]

after = tvm.tir.transform.LowerWarpMemory()(Before)

tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
tvm.testing.main()