Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
* \param dtype The data type of the variable.
* \return The result variable which gets bound to the thread env.
*/
Var EnvThread(String thread_tag);
Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));

/*!
* \brief Store data in a buffer.
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,21 +1241,24 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
"""Bind a var to thread env

Parameters
----------
thread_tag : str
The thread type tag.

dtype : str
The data type of the thread env.

Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.

"""
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member


def buffer_store(
Expand Down
10 changes: 5 additions & 5 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}
ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
if (!iter_var->dom.defined()) {
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
Range(tvm::tir::make_zero(extent.dtype()), extent);
} else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
<< iter_var->dom->extent << " vs " << extent;
Expand All @@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}

LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
return LaunchThread(EnvThread(thread_tag), extent);
return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
}

RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
Expand Down Expand Up @@ -512,9 +513,8 @@ ElseFrame Else() {
return ElseFrame(n);
}

Var EnvThread(String thread_tag) {
IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
thread_tag);
Var EnvThread(String thread_tag, DataType dtype) {
IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
Var var = iter_var->var;
if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
opt_frame.value()->env_threads.Set(var, iter_var);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,9 @@ def expected(A: T.Buffer((32, 128), "float16")):
T.ptx_cp_async(
"float16",
A_shared.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
A.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
16,
)
T.ptx_commit_group()
Expand Down
15 changes: 15 additions & 0 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,5 +471,20 @@ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> No
tvm.ir.assert_structural_equal(func, expected)


def test_launch_thread_i64():
"""Test launching thread with int64"""

@T.prim_func
def func() -> None:
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
if blockIdx_x == T.int64(0):
T.evaluate(T.int64(0))
else:
T.evaluate(T.int64(1))

assert func.body.node.dom.min.dtype == "int64"
assert func.body.node.dom.extent.dtype == "int64"


if __name__ == "__main__":
tvm.testing.main()