Skip to content

Commit ff1b88e

Browse files
author
Siyuan Feng
committed
[TVMScript] Support T.launch_thread with i64 dtype
This PR fixes the bug of mismatched dtype in `T.launch_thread` when the dtype is `i64`.
1 parent 57316da commit ff1b88e

File tree

4 files changed

+27
-8
lines changed

4 files changed

+27
-8
lines changed

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
401401
/*!
402402
* \brief Bind a var to thread env.
403403
* \param thread_tag The thread type tag.
404+
* \param dtype The data type of the variable.
404405
* \return The result variable which gets bound to the thread env.
405406
*/
406-
Var EnvThread(String thread_tag);
407+
Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
407408

408409
/*!
409410
* \brief Store data in a buffer.

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,21 +1241,24 @@ def launch_thread(
12411241
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member
12421242

12431243

1244-
def env_thread(thread_tag: str) -> IterVar:
1244+
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
12451245
"""Bind a var to thread env
12461246
12471247
Parameters
12481248
----------
12491249
thread_tag : str
12501250
The thread type tag.
12511251
1252+
dtype : str
1253+
The data type of the thread env.
1254+
12521255
Returns
12531256
-------
12541257
res : IterVar
12551258
The result iteration variable gets bound to the thread env.
12561259
12571260
"""
1258-
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
1261+
return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member
12591262

12601263

12611264
def buffer_store(

src/script/ir_builder/tir/ir.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
432432
}
433433
ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
434434
if (!iter_var->dom.defined()) {
435-
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
435+
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
436+
Range(tvm::tir::make_zero(extent.dtype()), extent);
436437
} else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
437438
LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
438439
<< iter_var->dom->extent << " vs " << extent;
@@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
444445
}
445446

446447
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
447-
return LaunchThread(EnvThread(thread_tag), extent);
448+
return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
448449
}
449450

450451
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
@@ -512,9 +513,8 @@ ElseFrame Else() {
512513
return ElseFrame(n);
513514
}
514515

515-
Var EnvThread(String thread_tag) {
516-
IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
517-
thread_tag);
516+
Var EnvThread(String thread_tag, DataType dtype) {
517+
IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
518518
Var var = iter_var->var;
519519
if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
520520
opt_frame.value()->env_threads.Set(var, iter_var);

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,5 +471,20 @@ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> No
471471
tvm.ir.assert_structural_equal(func, expected)
472472

473473

474+
def test_launch_thread_i64():
475+
"""Test launching thread with int64"""
476+
477+
@T.prim_func
478+
def func() -> None:
479+
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
480+
if blockIdx_x == T.int64(0):
481+
T.evaluate(T.int64(0))
482+
else:
483+
T.evaluate(T.int64(1))
484+
485+
assert func.body.node.dom.min.dtype == "int64"
486+
assert func.body.node.dom.extent.dtype == "int64"
487+
488+
474489
if __name__ == "__main__":
475490
tvm.testing.main()

0 commit comments

Comments
 (0)