Skip to content

Commit 684a838

Browse files
authored
[TIR] Avoid unnecessary dtype escalation in loop splitting (#12035)
This PR introduces a type check to cast loop split decisions (sometimes given as `int64`) back to a smaller datatype when the loop variable's data type is smaller. This issue usually happens during reloading a trace from disk using JSON database and causes the failure of `CompactBufferAllocation` pass.
1 parent f769f4e commit 684a838

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/tir/schedule/concrete_schedule.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
452452
if (is_const_int(factor) && !is_positive_const(factor)) {
453453
throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
454454
}
455+
if (factor.dtype().bits() > loop->extent.dtype().bits()) {
456+
factor = cast(loop->extent.dtype(), factor);
457+
}
455458
factors.push_back(factor);
456459
tot_length *= factor;
457460
}

tests/python/unittest/test_tir_schedule_split_fuse.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tvm.testing
2121
from tvm import te, tir
2222
from tvm.script import tir as T
23+
from tvm.tir.expr import IntImm
2324
from tvm.tir.schedule.testing import verify_trace_roundtrip
2425

2526
# pylint: disable=no-member,invalid-name,unused-variable
@@ -637,5 +638,13 @@ def _create_prim_func():
637638
)
638639

639640

641+
def test_split_int64_factors():
642+
sch = tir.Schedule(elementwise_symbolic, debug_mask="all")
643+
block_b = sch.get_block("B")
644+
_, _, k = sch.get_loops(block_b)
645+
sch.split(k, factors=[IntImm(dtype="int64", value=10), None])
646+
tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"])
647+
648+
640649
if __name__ == "__main__":
641650
tvm.testing.main()

0 commit comments

Comments
 (0)