Skip to content

Commit f1aecb5

Browse files
committed
[TIR, Schedule] Update block flags and simplify predicate in Reverse-Compute-Inline
1 parent b139736 commit f1aecb5

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ class ReverseComputeInliner : public BaseInliner {
651651
// Substitute the producer block iters with the its bindings since the predicate in BlockRealize
652652
// should not contain the block iters
653653
predicate = Substitute(predicate, subst_map);
654+
predicate = analyzer_.Simplify(predicate);
654655
return predicate;
655656
}
656657

@@ -865,6 +866,13 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
865866
return;
866867
}
867868
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
869+
// Step 8. Update the cached flags
870+
arith::Analyzer analyzer;
871+
BlockInfo& block_info = self->block_info[producer_block_sref];
872+
block_info.affine_binding = IsAffineBinding(
873+
/*realize=*/GetBlockRealize(self, producer_block_sref),
874+
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(producer_block_sref->parent)),
875+
/*analyzer=*/&analyzer);
868876
}
869877

870878
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {

tests/python/unittest/test_tir_schedule_compute_inline.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,34 @@ def elementwise_overcomputed_producer_reverse_inlined(
611611
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
612612

613613

614+
@T.prim_func
615+
def elementwise_overcomputed_producer_simplify_predicate(
616+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
617+
) -> None:
618+
B = T.alloc_buffer((128, 128))
619+
for i in T.grid(16384):
620+
with T.block("B"):
621+
vi = T.axis.spatial(128, i // 128)
622+
vj = T.axis.spatial(128, i % 128)
623+
B[vi, vj] = A[vi, vj] * 2.0
624+
for i, j in T.grid(127, 127):
625+
with T.block("C"):
626+
cvi, cvj = T.axis.remap("SS", [i, j])
627+
C[cvi, cvj] = B[cvi, cvj] + 1.0
628+
629+
630+
@T.prim_func
631+
def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
632+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
633+
) -> None:
634+
for i in T.grid(16384):
635+
with T.block("B"):
636+
vi = T.axis.spatial(128, i // 128)
637+
vj = T.axis.spatial(128, i % 128)
638+
T.where(i < 16255 and i % 128 < 127)
639+
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
640+
641+
614642
@T.prim_func
615643
def elementwise_producer_not_cover_consumer(
616644
A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32")
@@ -1025,6 +1053,15 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
10251053
)
10261054

10271055

1056+
def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name):
1057+
"""Test reverse compute inline overcomputed producer where the predicate should be simplified"""
1058+
sch = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all")
1059+
compute = "C" if use_block_name else sch.get_block("C")
1060+
sch.reverse_compute_inline(compute)
1061+
tvm.ir.assert_structural_equal(
1062+
elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"]
1063+
)
1064+
10281065
def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
10291066
"""Test reverse compute inline failure when the inlined block iter domains are not covered by
10301067
its producer

0 commit comments

Comments
 (0)