Commit 6252fa5

Author: Siyuan Feng
[TIR] Enhance CLZ intrinsic support (#16952)
Parent: bc8742b

File tree

7 files changed: +94 -5 lines

.github/workflows/main.yml
src/target/intrin_rule.h
src/target/source/intrin_rule_cuda.cc
src/target/source/intrin_rule_metal.cc
src/target/source/intrin_rule_opencl.cc
src/tir/ir/data_type_rewriter.cc
tests/python/codegen/test_target_codegen_gpu_common.py (new)

.github/workflows/main.yml

Lines changed: 2 additions & 0 deletions

@@ -77,6 +77,8 @@ jobs:
       - name: Minimal Metal Compile-and-Run
         shell: bash -l {0}
         run: >-
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_metal.py'
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_gpu_common.py'
           python -m pytest -v -s 'tests/python/codegen/test_gpu_codegen_allreduce.py::test_allreduce_sum[dims0-metal]'
       # - name: Test iOS RPC
       #   shell: bash -l {0}

src/target/intrin_rule.h

Lines changed: 15 additions & 3 deletions

@@ -53,8 +53,13 @@ struct Direct {
   std::string operator()(DataType t, std::string name) const { return name; }
 };
 
-// Call pure extern function.
-template <typename T>
+/*!
+ * \brief Dispatch pure extern function.
+ * \param e The call expression.
+ * \tparam T The function to dispatch.
+ * \tparam dtype_from_arg Whether the dtype is from the first argument or the call node
+ */
+template <typename T, bool dtype_from_arg = false>
 inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   const CallNode* call = e.as<CallNode>();
   ICHECK(call != nullptr);
@@ -64,7 +69,14 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   ICHECK(op != nullptr);
   std::string name = op->name;
   ICHECK_EQ(name.substr(0, 4), "tir.");
-  name = T()(call->dtype, name.substr(4));
+  DataType dtype;
+  if (dtype_from_arg) {
+    ICHECK_EQ(call->args.size(), 1U);
+    dtype = call->args[0].dtype();
+  } else {
+    dtype = call->dtype;
+  }
+  name = T()(dtype, name.substr(4));
 
   if (name.length() != 0) {
     Array<PrimExpr> new_args = {StringImm(name)};
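Why the new dtype_from_arg flag matters for clz: TVM's Python helper tir.clz builds an int32-typed call regardless of the operand's width, so extern-name mangling has to look at the argument's dtype rather than the call's. A minimal sketch of that asymmetry (pure expression construction, no target needed):

import tvm
from tvm import tir

# tir.clz yields an int32-typed Call even for a 64-bit operand, so
# dispatching on the call dtype alone cannot tell __clz from __clzll.
e = tir.clz(tir.const(1, "int64"))
assert e.dtype == "int32"          # the call node's dtype
assert e.args[0].dtype == "int64"  # the dtype that dtype_from_arg inspects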

src/target/source/intrin_rule_cuda.cc

Lines changed: 12 additions & 0 deletions

@@ -54,6 +54,15 @@ struct CUDAMath {
     }
   } else if (t.is_bfloat16()) {
     return 'h' + name;
+  } else if (t.is_int() || t.is_uint()) {
+    switch (t.bits()) {
+      case 32:
+        return "__" + name;
+      case 64:
+        return "__" + name + "ll";
+      default:
+        return "";
+    }
   }
   return "";
 }
@@ -133,6 +142,9 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
+    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath, /*dtype_from_arg=*/true>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
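With CUDAMath extended to integer types, a 32-bit operand mangles to CUDA's __clz and a 64-bit operand to __clzll. A hedged end-to-end sketch (assumes a CUDA-enabled TVM build with a device available; mirrors the new test added below):

import tvm
from tvm import te

# clz over int64: with dtype_from_arg dispatch, the generated kernel
# should call __clzll rather than the 32-bit __clz.
A = te.placeholder((16,), name="A", dtype="int64")
B = te.compute(A.shape, lambda i: tvm.tir.clz(A[i]), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
(x,) = sch.get_loops(sch.get_block("B"))
sch.bind(x, "threadIdx.x")
f = tvm.build(sch.mod, target="cuda")
print(f.imported_modules[0].get_source())  # expect a __clzll(...) call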

src/target/source/intrin_rule_metal.cc

Lines changed: 3 additions & 0 deletions

@@ -52,6 +52,9 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

src/target/source/intrin_rule_opencl.cc

Lines changed: 3 additions & 0 deletions

@@ -31,6 +31,9 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
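Metal and OpenCL both use the Direct dispatcher, which keeps the intrinsic name unchanged; that works because Metal Shading Language and OpenCL C each provide a native clz builtin. A quick hedged check, analogous to the CUDA sketch above (assumes an OpenCL-enabled build):

import tvm
from tvm import te

# Direct dispatch: "tir.clz" should appear as a plain clz(...) call
# in the generated OpenCL C source.
A = te.placeholder((16,), name="A", dtype="uint32")
B = te.compute(A.shape, lambda i: tvm.tir.clz(A[i]), name="B")
sch = tvm.tir.Schedule(te.create_prim_func([A, B]))
sch.bind(sch.get_loops(sch.get_block("B"))[0], "threadIdx.x")
f = tvm.build(sch.mod, target="opencl")
assert "clz(" in f.imported_modules[0].get_source()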

src/tir/ir/data_type_rewriter.cc

Lines changed: 4 additions & 2 deletions

@@ -238,10 +238,12 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(Op::Get("tir.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
-    CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64))
+    CHECK((before_dtype.is_int() || before_dtype.is_uint()) &&
+          (before_dtype.bits() == 32 || before_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type before legalizing: "
         << before_dtype;
-    CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64))
+    CHECK((after_dtype.is_int() || after_dtype.is_uint()) &&
+          (after_dtype.bits() == 32 || after_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type after legalizing: "
         << after_dtype;
     return e - after_dtype.bits() + before_dtype.bits();
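The compensation on the last line is worth a worked check: widening an argument from 32 to 64 bits adds exactly 32 leading zeros to every value, so subtracting after_dtype.bits() and adding back before_dtype.bits() recovers the count at the original width. In pure Python:

# clz over a fixed bit width, valid for positive x
def clz(x: int, bits: int) -> int:
    return bits - x.bit_length()

x = 1000
# e - after_bits + before_bits == clz at the original width
assert clz(x, 64) - 64 + 32 == clz(x, 32)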
tests/python/codegen/test_target_codegen_gpu_common.py (new file)

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from functools import partial
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal", "vulkan -supports_int64=1", "opencl")
+@pytest.mark.parametrize("dtype", ["int32", "uint32", "int64", "uint64"])
+def test_int_intrin(target, dev, dtype):
+    test_funcs = [
+        (tvm.tir.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)),
+    ]
+
+    def run_test(tvm_intrin, np_func, dtype):
+        n = 128
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B")
+        func = te.create_prim_func([A, B])
+        sch = tvm.tir.Schedule(func)
+        (x,) = sch.get_loops(sch.get_block("B"))
+        sch.bind(x, "threadIdx.x")
+        f = tvm.build(sch.mod, target=target)
+        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
+        f(a, b)
+        ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy())
+        tvm.testing.assert_allclose(b.numpy(), ref)
+
+    for func in test_funcs:
+        run_test(*func, dtype)


if __name__ == "__main__":
    tvm.testing.main()
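A note on the NumPy reference in the test: for positive x, len(bin(x)) - 2 strips the '0b' prefix and equals x.bit_length(), while int(dtype[-2:]) extracts the bit width (32 or 64) from the dtype name, so the expected value is width minus bit length. A quick check of that identity:

x, dtype = 100000, "uint64"
assert int(dtype[-2:]) == 64              # width parsed from the dtype name
assert len(bin(x)) - 2 == x.bit_length()  # bit length of a positive int
assert int(dtype[-2:]) - (len(bin(x)) - 2) == 64 - x.bit_length()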
