@@ -978,5 +978,29 @@ def test_llvm_target_attributes():
978978 assert n in functions_with_target
979979
980980
981+ @tvm .testing .requires_llvm
982+ def test_llvm_assume ():
983+ """
984+ Check that LLVM does not error out when generating code with tir.assume.
985+ Verifying for llvm.assume being generated is not easy as the intrinsic and its
986+ related instructions get removed during optimizations
987+ """
988+
989+ @T .prim_func
990+ def tir_assume_func (A : T .Buffer ((4 , 4 ), "int32" ), B : T .Buffer ((14 ,), "int32" )):
991+ T .func_attr ({"global_symbol" : "main" , "tir.noalias" : True })
992+ A_1 = T .Buffer ((16 ,), "int32" , data = A .data )
993+ for axis0 , axis1 in T .grid (4 , 4 ):
994+ T .assume (axis0 < 3 or axis1 < 2 or A_1 [axis0 * 4 + axis1 ] == 0 )
995+ for i in range (14 ):
996+ B_1 = T .Buffer ((14 ,), "int32" , data = B .data )
997+ B_1 [i ] = A_1 [i ] * 2
998+
999+ mod = tvm .IRModule .from_expr (tir_assume_func )
1000+ inp = te .placeholder ((4 , 4 ), name = "A" , dtype = "int32" )
1001+ out = te .placeholder ((14 ,), name = "B" , dtype = "int32" )
1002+ m = tvm .build (mod , [inp , out ], target = "llvm" )
1003+
1004+
9811005if __name__ == "__main__" :
9821006 tvm .testing .main ()
0 commit comments