Skip to content

Commit 78a1f80

Browse files
authored
[CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
* Vector-Codegen support for llvm-pure-intrin
1 parent f5d3fc2 commit 78a1f80

File tree

4 files changed

+103
-2
lines changed

4 files changed

+103
-2
lines changed

src/tir/op/builtin.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
139139
TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
140140
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
141141
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
142-
Integer(ScriptDtypePrintLocation::kFirst));
142+
Integer(ScriptDtypePrintLocation::kFirst))
143+
.set_attr<TVectorizable>("TVectorizable", true);
143144

144145
TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
145146
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

src/tir/transforms/vectorize_loop.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
550550
}
551551
} else {
552552
int lane = 0;
553-
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
553+
Array<PrimExpr> new_args;
554+
if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
555+
// op->args[1] gives the total number of arguments passed to the intrinsic
556+
int num_signature = Downcast<IntImm>(op->args[1])->value;
557+
Array<PrimExpr> op_expr_args;
558+
for (int i = 0; i < num_signature; i++) {
559+
// Collect all intrinsic arguments
560+
op_expr_args.push_back(op->args[i + 2]);
561+
}
562+
// Generate RAMP nodes for intrinsic arguments
563+
Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
564+
// Collect the intrinsic ID and the number of arguments
565+
for (int i = 0; i < 2; i++) {
566+
new_args.push_back(op->args[i]);
567+
}
568+
// Collect updated intrinsic arguments
569+
for (int i = 0; i < num_signature; i++) {
570+
new_args.push_back(updated_args[i]);
571+
}
572+
} else {
573+
new_args = MutateArray(op->args, &lane);
574+
}
554575
// normal code path.
555576
if (op->args.same_as(new_args)) {
556577
return GetRef<PrimExpr>(op);

tests/python/tir-transform/test_tir_transform_vectorize.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,5 +790,63 @@ def expected(a: T.handle, b: T.handle):
790790
tvm.ir.assert_structural_equal(after, expected)
791791

792792

793+
@pytest.mark.parametrize(
794+
"extent, vec_str, target",
795+
[(4, "float32x4", simple_target)],
796+
)
797+
def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
798+
@I.ir_module
799+
class Before:
800+
@T.prim_func
801+
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
802+
for j in T.vectorized(extent):
803+
A[j] = T.call_llvm_pure_intrin(
804+
"float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
805+
)
806+
807+
@I.ir_module
808+
class After:
809+
@T.prim_func
810+
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
811+
A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
812+
vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
813+
)
814+
815+
with tvm.target.Target(target):
816+
mod = tvm.tir.transform.VectorizeLoop()(Before)
817+
tvm.ir.assert_structural_equal(mod, After)
818+
mod = tvm.build(mod, target)
819+
820+
821+
@pytest.mark.parametrize(
822+
"extent, vec_str, target",
823+
[(4, "int32x4", simple_target)],
824+
)
825+
def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
826+
@I.ir_module
827+
class Before:
828+
@T.prim_func
829+
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
830+
for j in T.vectorized(extent):
831+
A[j] = T.call_llvm_pure_intrin(
832+
"int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
833+
)
834+
835+
@I.ir_module
836+
class After:
837+
@T.prim_func
838+
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
839+
A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
840+
vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
841+
)
842+
843+
with pytest.raises(Exception) as e_info:
844+
with tvm.target.Target(target):
845+
mod = tvm.tir.transform.VectorizeLoop()(Before)
846+
ex = tvm.build(mod, target)
847+
tvm.ir.assert_structural_equal(mod, After)
848+
assert "Intrinsic does not support vectors" in e_info.value.args[0]
849+
850+
793851
if __name__ == "__main__":
794852
tvm.testing.main()

tests/python/tvmscript/test_tvmscript_printer_tir.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
10451045
_assert_print(main, expected_output)
10461046

10471047

1048+
def test_vectorize_llvm_pure_intrin():
1049+
from tvm.script import tir as T
1050+
1051+
@T.prim_func
1052+
def main(a: T.handle, b: T.handle):
1053+
A = T.match_buffer(a, (4,), "float32")
1054+
B = T.match_buffer(b, (4,), "float32")
1055+
A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
1056+
"float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
1057+
)
1058+
1059+
expected_output = """
1060+
# from tvm.script import tir as T
1061+
1062+
@T.prim_func
1063+
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
1064+
A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
1065+
"""
1066+
_assert_print(main, expected_output)
1067+
1068+
10481069
if __name__ == "__main__":
10491070
tvm.testing.main()

0 commit comments

Comments
 (0)