Skip to content

Commit c9d87ef

Browse files
[Relax][Bugfix] Annotate ComputePrimValue output as host function (#17032)
The `ComputePrimValue` transform is used to compute the value of symbolic expressions that may appear within a Relax function. For example, to compute a boolean condition used for a `relax::If` node. These functions are used for small host-side computations, prior to launching a device kernel. This commit updates `ComputePrimValue` to annotate the generated `PrimFunc` with `tir::attr::kIsHostFunc`. This annotation is required for correct behavior in `tvm.dlight.ApplyDefaultSchedule`, to avoid erroneous scheduling of this function for the GPU, and for `tir::transform::BindTarget`, to ensure that the function is compiled for execution on the host. Co-authored-by: Chris Sullivan <[email protected]>
1 parent b2c6116 commit c9d87ef

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/relax/transform/compute_prim_value.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class PrimValueComputeInjector : public ExprMutator {
4545
auto param_vars = tir::UndefinedVars(node->value);
4646
tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value}));
4747

48-
tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
48+
tir::PrimFunc func(param_vars, body, PrimType(ret_dtype), {},
49+
DictAttrs({{tir::attr::kIsHostFunc, Bool(true)}}));
4950
func = tir::RenewDefs(func);
5051

5152
auto callee = builder_->AddFunction(func, "compute_symbolic_expr");

tests/python/relax/test_transform_compute_prim_value.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def main(A: R.Tensor(["N"])):
4444

4545
@T.prim_func(private=True)
4646
def compute_symbolic_expr(N: T.int64) -> T.bool:
47+
T.func_attr({"tir.is_host_func": True})
4748
T.ret(N % 16 == 0)
4849

4950

@@ -73,6 +74,7 @@ def main(A: R.Tensor(["N"])):
7374

7475
@T.prim_func(private=True)
7576
def compute_symbolic_expr(N: T.int64) -> T.bool:
77+
T.func_attr({"tir.is_host_func": True})
7678
T.ret(N % 16 == 0)
7779

7880

@@ -97,6 +99,7 @@ def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> R.Prim(value="N*M"):
9799

98100
@T.prim_func(private=True)
99101
def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
102+
T.func_attr({"tir.is_host_func": True})
100103
T.ret(N * M)
101104

102105

0 commit comments

Comments
 (0)