Skip to content

Commit 98a3eb3

Browse files
ekalda, lhutton1, Neil Hickey
committed
[SVE] Add vscale builtin
Add a vscale builtin and lowering to `llvm.vscale`. This will be used in subsequent patches for expressing scalable vectors in TIR.

Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
1 parent 90320b2 commit 98a3eb3

File tree

7 files changed

+46
-3
lines changed

7 files changed

+46
-3
lines changed

include/tvm/tir/builtin.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,11 @@ TVM_DLL const Op& anylist_setitem_call_packed();
909909
*/
910910
TVM_DLL const Op& anylist_setitem_call_cpacked();
911911

912+
/*!
913+
* \brief Get the target's vscale value
914+
*/
915+
TVM_DLL const Op& vscale();
916+
912917
/*! \brief The kind of structure field info used in intrinsic */
913918
enum TVMStructFieldKind : int {
914919
// array head address

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,6 +1862,7 @@ def wrapped(*args, **kwargs):
18621862
anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem)
18631863
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
18641864
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
1865+
vscale = _op_wrapper(_tir_op.vscale)
18651866

18661867

18671868
def _dtype_forward(func):
@@ -2199,4 +2200,5 @@ def wrapped(*args, **kwargs):
21992200
"IterVar",
22002201
"CommReducer",
22012202
"Range",
2203+
"vscale",
22022204
]

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
8989
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
9090
from .op import start_profile_intrinsic, end_profile_intrinsic
91+
from .op import vscale
9192
from .generic import add, subtract, multiply
9293

9394
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

python/tvm/tir/op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3338,6 +3338,16 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args):
33383338
)
33393339

33403340

3341+
def vscale():
    """Build a call to the ``tir.vscale`` intrinsic.

    The call takes no arguments and is typed as ``int32``. It stands for the
    target's vscale value.

    Returns
    -------
    call : PrimExpr
        Call to the vscale intrinsic
    """
    result_dtype = "int32"
    intrinsic_name = "tir.vscale"
    return call_intrin(result_dtype, intrinsic_name)
3349+
3350+
33413351
# pylint: disable=unnecessary-lambda
33423352
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
33433353
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore

src/target/llvm/codegen_llvm.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
14781478
return builder_->CreateAssumption(cond);
14791479
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
14801480
return MakeValue(op->args[0]);
1481+
#if TVM_LLVM_VERSION >= 110
1482+
} else if (op->op.same_as(builtin::vscale())) {
1483+
llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
1484+
llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
1485+
return builder_->CreateCall(f);
1486+
#endif
14811487
} else {
14821488
LOG(FATAL) << "unknown intrinsic " << op->op;
14831489
}

src/tir/op/builtin.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed)
394394

395395
TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked)
396396
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
397+
398+
TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr<TCallEffectKind>("TCallEffectKind",
399+
Integer(CallEffectKind::kOpaque));
397400
} // namespace builtin
398401
} // namespace tir
399402
} // namespace tvm

tests/python/codegen/test_target_codegen_aarch64.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@
1616
# under the License.
1717
import tvm
1818
from tvm import te
19-
from tvm.script import tir as TIR
19+
from tvm.script import tir as T
2020
import re
21-
import os
22-
import ctypes
2321
import pytest
2422

2523
from tvm.target.codegen import llvm_version_major
@@ -476,5 +474,23 @@ def check_correct_assembly(type):
476474
check_correct_assembly(type=dtype)
477475

478476

477+
# The LLVM codegen only lowers tir.vscale when built with TVM_LLVM_VERSION >= 110
# (LLVM 11), so the test must be skipped on anything older than 11.
@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
def test_codegen_vscale():
    """Check that ``tvm.tir.vscale()`` is lowered to the ``llvm.vscale`` intrinsic.

    Builds a small PrimFunc that stores ``2 * vscale`` into every element of a
    5-element int32 buffer for an AArch64 target with SVE enabled, then looks
    for the intrinsic name in the module source returned by ``get_source()``.
    """
    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
    vscale = tvm.tir.vscale()

    @T.prim_func
    def main(A: T.Buffer((5,), "int32")):
        for i in range(5):
            A[i] = 2 * vscale

    build_mod = tvm.build(main, target=target)
    llvm = build_mod.get_source()

    # Dots are escaped so the pattern matches the literal intrinsic name only.
    assert re.findall(r"llvm\.vscale\.i32", llvm), "No vscale in generated LLVM."
493+
494+
479495
if __name__ == "__main__":
480496
tvm.testing.main()

0 commit comments

Comments
 (0)