Skip to content

Commit f1ba5ed

Browse files
authored
[BugFix][TIR] Schedule support reverse-inline with reduction blocks (#17838)
This PR fixes a bug in the reverse-compute-inline primitive of TIR Schedule: inlining a transpose block into a reduction block previously generated incorrect TIR.
1 parent 6bd55f0 commit f1ba5ed

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,30 @@ class ReverseComputeInliner : public BaseInliner {
586586
ReverseComputeInliner* self_;
587587
};
588588

589+
// Expression mutator declared inside ReverseComputeInliner. It rewrites an expression in two ways:
//   - a Var that has an entry in self_->idx_sub_ is replaced by the mapped expression;
//   - a BufferLoad of self_->inlined_buffer_ is replaced by a load of
//     self_->inlined_store_->buffer at self_->inlined_store_->indices.
// ReplaceInlinedBuffer runs it over "producer->value" before storing the result in producer_rhs_.
class RecursionResolver : public StmtExprMutator {
590+
 public:
591+
// Keeps a non-owning pointer to the enclosing inliner whose state is read during the rewrite.
explicit RecursionResolver(ReverseComputeInliner* self) : self_(self) {}
592+
593+
 private:
594+
// Returns the expression recorded for `var` in idx_sub_, or the variable itself when it has
// no entry there.
PrimExpr VisitExpr_(const VarNode* var) final {
595+
auto it = self_->idx_sub_.find(var);
596+
if (it == self_->idx_sub_.end()) {
597+
return GetRef<Var>(var);
598+
}
599+
return (*it).second;
600+
}
601+
602+
// Mutates the load through the base-class visitor first, then checks which buffer it reads.
// Loads of any other buffer are returned as mutated. A load of the inlined buffer is swapped for
// a load built from inlined_store_'s buffer and indices, and that new load is passed through
// VisitExpr again (presumably so the Var substitution above also reaches its indices — confirm).
PrimExpr VisitExpr_(const BufferLoadNode* _load) final {
603+
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load));
604+
return load->buffer.same_as(self_->inlined_buffer_)
605+
? StmtExprMutator::VisitExpr(
606+
BufferLoad(self_->inlined_store_->buffer, self_->inlined_store_->indices))
607+
: load;
608+
}
609+
610+
// Enclosing inliner; provides idx_sub_, inlined_buffer_ and inlined_store_.
ReverseComputeInliner* self_;
611+
};
612+
589613
public:
590614
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
591615
const BlockRealize& consumer_block_realize,
@@ -784,7 +808,9 @@ class ReverseComputeInliner : public BaseInliner {
784808
}
785809

786810
// Records the producer's right-hand side in producer_rhs_ and returns the statement obtained by
// running Substituter over inlined_store_.
// NOTE(review): this is a diff view — the assignment marked "787-" is the removed pre-fix line;
// the lines marked "811+".."813+" are its replacement, which first passes the value through
// RecursionResolver.
Stmt ReplaceInlinedBuffer(BufferStore producer) {
787-
producer_rhs_ = producer->value;
811+
// "producer->value" may contain the buffer that is inlined in cases of reduction,
812+
// so we need to resolve the recursion first
813+
producer_rhs_ = RecursionResolver(this)(producer->value);
788814
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
789815
}
790816

tests/python/tir-schedule/test_tir_schedule_compute_inline.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,5 +1529,54 @@ def after(
15291529
assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
15301530

15311531

1532+
# Regression test: reverse_compute_inline of the "transpose" block when the block that writes the
# inlined buffer ("bmm") has a T.init() and accumulates into that buffer.
def test_inline_with_reduction():
1533+
# Input function: "bmm" initializes and accumulates into the intermediate T_batch_matmul_NN,
# and "transpose" copies T_batch_matmul_NN[v0, 0, v1] into T_transpose[0, 0, v0, v1].
@T.prim_func
1534+
def before(
1535+
T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
1536+
T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
1537+
T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
1538+
):
1539+
T_batch_matmul_NN = T.alloc_buffer((T.int64(6), T.int64(1), T.int64(64)))
1540+
for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
1541+
with T.block("bmm"):
1542+
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1543+
T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
1544+
T.writes(T_batch_matmul_NN[v0, T.int64(0), v1])
1545+
with T.init():
1546+
T_batch_matmul_NN[v0, T.int64(0), v1] = T.float32(0.0)
1547+
T_batch_matmul_NN[v0, T.int64(0), v1] = (
1548+
T_batch_matmul_NN[v0, T.int64(0), v1]
1549+
+ T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
1550+
)
1551+
for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
1552+
with T.block("transpose"):
1553+
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1554+
T.reads(T_batch_matmul_NN[v0, T.int64(0), v1])
1555+
T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
1556+
T_transpose[T.int64(0), T.int64(0), v0, v1] = T_batch_matmul_NN[v0, T.int64(0), v1]
1557+
1558+
# Expected function: no intermediate buffer and no "transpose" block; "bmm" initializes and
# accumulates directly in T_transpose[0, 0, v0, v1], including the self-read on the right-hand side.
@T.prim_func
1559+
def after(
1560+
T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
1561+
T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
1562+
T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
1563+
):
1564+
for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
1565+
with T.block("bmm"):
1566+
v0, v1 = T.axis.remap("SS", [ax0, ax1])
1567+
T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
1568+
T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
1569+
with T.init():
1570+
T_transpose[T.int64(0), T.int64(0), v0, v1] = T.float32(0.0)
1571+
T_transpose[T.int64(0), T.int64(0), v0, v1] = (
1572+
T_transpose[T.int64(0), T.int64(0), v0, v1]
1573+
+ T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
1574+
)
1575+
1576+
# Apply reverse_compute_inline to the "transpose" block and check the scheduled module's "main"
# is structurally equal to `after`.
sch = tir.Schedule(before)
1577+
sch.reverse_compute_inline(sch.get_block("transpose"))
1578+
assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
1579+
1580+
15321581
# Script entry point: delegates to tvm.testing.main() when the file is executed directly.
if __name__ == "__main__":
15331582
tvm.testing.main()

0 commit comments

Comments
 (0)