Skip to content

Commit a8157e2

Browse files
Fix bug when decompose padding wrt the single child subtree
1 parent 209845f commit a8157e2

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

src/tir/schedule/primitive/decompose_padding.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {
114114

115115
// Step 3. Analyze in-bound write region.
116116
PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
117+
if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
118+
SetError("The in-bound predicate is trivial");
119+
return false;
120+
}
117121
Array<Range> in_bound_region = this->EstimateInBoundRegion(
118122
/*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
119123
/*in_bound_predicate=*/in_bound_predicate);
@@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
439443
analyzer.Bind(cur_loop->loop_var, range);
440444
loops.push_back(cur_loop);
441445

442-
if (!found_const_filling_pos) {
443-
if (cur_loop.same_as(const_filling_pos)) {
444-
found_const_filling_pos = true;
446+
if (cur_loop.same_as(const_filling_pos)) {
447+
ICHECK(!found_const_filling_pos);
448+
found_const_filling_pos = true;
449+
if (!found_in_bound_filling_pos) {
450+
found_in_bound_filling_pos = true;
451+
in_bound_filling_pos = cur_loop;
445452
}
446-
}
447-
448-
if (!found_in_bound_filling_pos) {
453+
} else if (!found_in_bound_filling_pos) {
449454
if (!cur_loop->body->IsInstance<ForNode>() &&
450455
!cur_loop->body->IsInstance<BlockRealizeNode>()) {
451456
found_in_bound_filling_pos = true;

tests/python/unittest/test_tir_schedule_decompose_padding.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,5 +309,68 @@ def pooling_decompose_3(
309309
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
310310

311311

312+
def test_decompose_wrt_single_child_subtree():
313+
"""Test the case when the decompose position is under the single child subtree"""
314+
315+
@T.prim_func
316+
def pad_op(
317+
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
318+
):
319+
for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
320+
with T.block("pad_temp"):
321+
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
322+
y[ax0, ax1, ax2, ax3] = T.if_then_else(
323+
3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
324+
x[ax0, ax1, ax2 - 3, ax3 - 3],
325+
T.int8(0),
326+
dtype="int8",
327+
)
328+
329+
@T.prim_func
330+
def pad_op_after(
331+
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
332+
):
333+
for i0, i1 in T.grid(1, 16):
334+
for i2, i3 in T.grid(231, 231):
335+
with T.block("pad_temp_pad_const"):
336+
ax0 = T.axis.spatial(1, 0)
337+
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
338+
y[ax0, ax1, ax2, ax3] = T.int8(0)
339+
for i2, i3 in T.grid(225, 225):
340+
with T.block("pad_temp"):
341+
ax0 = T.axis.spatial(1, 0)
342+
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
343+
y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
344+
345+
sch = tir.Schedule(pad_op, debug_mask="all")
346+
pad = sch.get_block("pad_temp")
347+
_, _, h, _ = sch.get_loops(pad)
348+
sch.decompose_padding(pad, h)
349+
check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)
350+
351+
352+
def test_not_to_decompose_trivial_predicate():
353+
"""Test the case when the padding condition is trivial"""
354+
355+
@T.prim_func
356+
def trivial_pad(
357+
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
358+
):
359+
for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
360+
with T.block("pad_temp"):
361+
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
362+
y[ax0, ax1, ax2, ax3] = T.if_then_else(
363+
0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
364+
x[ax0, ax1, ax2, ax3],
365+
T.int8(0),
366+
dtype="int8",
367+
)
368+
369+
sch = tir.Schedule(trivial_pad, debug_mask="all")
370+
pad = sch.get_block("pad_temp")
371+
_, _, h, _ = sch.get_loops(pad)
372+
assert not sch.can_decompose_padding(pad, h)
373+
374+
312375
if __name__ == "__main__":
313376
tvm.testing.main()

0 commit comments

Comments
 (0)