Skip to content

Commit 32e500b

Browse files
[LLVM] Add support to generate llvm.assume (#14294)
This commit adds support for generating `llvm.assume` from `tir.assume`, because lowering code that contains `tir.assume` currently fails with an error in certain cases.
1 parent f4520c4 commit 32e500b

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

src/target/llvm/codegen_llvm.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
13871387
op->op.same_as(builtin::end_profile_intrinsic())) {
13881388
LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
13891389
return nullptr;
1390+
} else if (op->op.same_as(builtin::assume())) {
1391+
llvm::Value* cond = MakeValue(op->args[0]);
1392+
return builder_->CreateAssumption(cond);
13901393
} else {
13911394
LOG(FATAL) << "unknown intrinsic " << op->op;
13921395
}

tests/python/unittest/test_target_codegen_llvm.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,5 +978,29 @@ def test_llvm_target_attributes():
978978
assert n in functions_with_target
979979

980980

981+
@tvm.testing.requires_llvm
982+
def test_llvm_assume():
983+
"""
984+
Check that LLVM does not error out when generating code with tir.assume.
985+
Verifying that llvm.assume is actually generated is not easy, as the intrinsic and its
986+
related instructions get removed during optimizations.
987+
"""
988+
989+
@T.prim_func
990+
def tir_assume_func(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")):
991+
T.func_attr({"global_symbol": "main", "tir.noalias": True})
992+
A_1 = T.Buffer((16,), "int32", data=A.data)
993+
for axis0, axis1 in T.grid(4, 4):
994+
T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0)
995+
for i in range(14):
996+
B_1 = T.Buffer((14,), "int32", data=B.data)
997+
B_1[i] = A_1[i] * 2
998+
999+
mod = tvm.IRModule.from_expr(tir_assume_func)
1000+
inp = te.placeholder((4, 4), name="A", dtype="int32")
1001+
out = te.placeholder((14,), name="B", dtype="int32")
1002+
m = tvm.build(mod, [inp, out], target="llvm")
1003+
1004+
9811005
if __name__ == "__main__":
9821006
tvm.testing.main()

0 commit comments

Comments
 (0)