Skip to content

Commit 044e863

Browse files
author
Siyuan Feng
committed
[VULKAN] Fix CLZ support for Vulkan
CLZ (counting leading zeros) is used for improving ceil_log2 performance on vulkan. however, the current implantation is incorrect during dtype converting. This PR contains: 1. Simplify clz for index calculation (happens in vulkan sort) 2. Fix clz for data type conversion
1 parent a309b6b commit 044e863

File tree

5 files changed

+61
-3
lines changed

5 files changed

+61
-3
lines changed

python/tvm/target/detect_target.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def _detect_vulkan(dev: Device) -> Target:
6767
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
6868
"thread_warp_size": dev.warp_size,
6969
"supports_float16": f_get_target_property(dev, "supports_float16"),
70-
"supports_int16": f_get_target_property(dev, "supports_int16"),
7170
"supports_int8": f_get_target_property(dev, "supports_int8"),
71+
"supports_int16": f_get_target_property(dev, "supports_int16"),
72+
"supports_int64": f_get_target_property(dev, "supports_int64"),
7273
"supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"),
7374
}
7475
)

src/arith/rewrite_simplify.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,6 +2250,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
22502250
}
22512251
}
22522252
}
2253+
} else if (op->op.same_as(Op::Get("tir.clz"))) {
2254+
if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
2255+
int bits = arg_int->dtype.bits();
2256+
if (arg_int->value == 0) return make_const(op->dtype, bits);
2257+
for (int i = bits - 1; i >= 0; --i) {
2258+
if ((int64_t(1) << i) & arg_int->value) {
2259+
return IntImm(op->dtype, bits - i - 1);
2260+
}
2261+
}
2262+
LOG(FATAL) << "Should not reach here";
2263+
}
22532264
}
22542265

22552266
if (op->op.same_as(tir::builtin::likely())) {

src/tir/ir/data_type_rewriter.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
215215
#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
216216

217217
PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
218+
Call before = GetRef<Call>(op);
218219
PrimExpr e = StmtExprMutator::VisitExpr_(op);
219220
op = e.as<CallNode>();
220221
static const Op& builtin_pow_ = Op::Get("tir.pow");
@@ -234,6 +235,16 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
234235
return pow(op->args[0], op->args[1]);
235236
} else if (op->op.same_as(builtin::if_then_else())) {
236237
return if_then_else(op->args[0], op->args[1], op->args[2]);
238+
} else if (op->op.same_as(Op::Get("tir.clz"))) {
239+
DataType before_dtype = before->args[0]->dtype;
240+
DataType after_dtype = op->args[0]->dtype;
241+
CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64))
242+
<< "clz only supports 32 or 64 bit integer types, but get type before legalizing: "
243+
<< before_dtype;
244+
CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64))
245+
<< "clz only supports 32 or 64 bit integer types, but get type after legalizing: "
246+
<< after_dtype;
247+
return e - after_dtype.bits() + before_dtype.bits();
237248
}
238249
return e;
239250
}

tests/python/arith/test_arith_rewrite_simplify.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
import pytest
2121

2222
import tvm
23+
import tvm.testing
2324
from tvm import te, tir
24-
25-
from tvm.tir import truncdiv as tdiv, truncmod as tmod, floordiv as fld, floormod as flm
25+
from tvm.tir import floordiv as fld
26+
from tvm.tir import floormod as flm
27+
from tvm.tir import truncdiv as tdiv
28+
from tvm.tir import truncmod as tmod
2629

2730

2831
class TestCase:
@@ -1150,5 +1153,18 @@ class TestIfThenElse(BaseCompare):
11501153
)
11511154

11521155

1156+
class TestCLZ(BaseCompare):
1157+
test_case = tvm.testing.parameter(
1158+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32),
1159+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31),
1160+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30),
1161+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24),
1162+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64),
1163+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63),
1164+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62),
1165+
TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56),
1166+
)
1167+
1168+
11531169
if __name__ == "__main__":
11541170
tvm.testing.main()

tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,5 +259,24 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32"), n: T.int32)
259259
tvm.ir.assert_structural_equal(Expected, after)
260260

261261

262+
def test_clz():
263+
@tvm.script.ir_module
264+
class Before:
265+
@T.prim_func
266+
def main(B: T.Buffer((T.int64(4),), "int32")):
267+
for i in T.serial(T.int64(4)):
268+
B[i] = T.clz(i)
269+
270+
@tvm.script.ir_module
271+
class Expected:
272+
@T.prim_func
273+
def main(B: T.Buffer((4,), "int32")):
274+
for i in range(4):
275+
B[i] = T.clz(i) - 32 + 64
276+
277+
after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
278+
tvm.ir.assert_structural_equal(Expected, after)
279+
280+
262281
if __name__ == "__main__":
263282
tvm.testing.main()

0 commit comments

Comments
 (0)