Skip to content

Commit 00c830e

Browse files
relax reorder primitive's affineness check (#10887)
1 parent c7c76d1 commit 00c830e

File tree

4 files changed

+156
-15
lines changed

4 files changed

+156
-15
lines changed

src/tir/schedule/analysis.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,17 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
231231
*/
232232
void CheckAffineBinding(const ScheduleState& self, Block block);
233233

234+
/*!
235+
* \brief Check whether a block has an affine binding under the high exclusive sref node,
236+
* throw an exception if the block does not have an affine binding.
237+
* \param self The schedule state
238+
* \param block The block to be checked
239+
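* \param high_exclusive The highest sref node (exclusive) under which affineness is checked; if not defined, only the global affine binding is checked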
240+
* \throw ScheduleError If the input block does not have an affine binding
241+
*/
242+
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
243+
const Optional<StmtSRef>& high_exclusive);
244+
234245
/*!
235246
* \brief Extracts the ranges of loop variables in a path of the sref tree
236247
* \param low_inclusive The lowest node in the path

src/tir/schedule/analysis/analysis.cc

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,26 +544,62 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
544544
return true;
545545
}
546546

547-
void CheckAffineBinding(const ScheduleState& self, Block block) {
547+
void CheckPartialAffineBinding(const ScheduleState& self, Block block,
548+
const Optional<StmtSRef>& high_exclusive) {
548549
class NotAffineBindingError : public ScheduleError {
549550
public:
550-
explicit NotAffineBindingError(IRModule mod, Block block)
551-
: mod_(std::move(mod)), block_(std::move(block)) {}
551+
explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive)
552+
: mod_(std::move(mod)), block_(std::move(block)) {
553+
if (high_exclusive.defined()) {
554+
high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>();
555+
}
556+
}
552557
String FastErrorString() const final {
553-
return "ScheduleError: The block is required to have an affine binding";
558+
std::ostringstream ss;
559+
if (high_exclusive_loop_) {
560+
ss << "ScheduleError: The block is required to have an partial affine binding under "
561+
<< high_exclusive_loop_->loop_var;
562+
} else {
563+
ss << "ScheduleError: The block is required to have an affine binding";
564+
}
565+
return ss.str();
554566
}
555567
String DetailRenderTemplate() const final {
556-
return "The block {0} is required to have an affine binding";
568+
std::ostringstream ss;
569+
if (high_exclusive_loop_) {
570+
ss << "The block {0} is required to have an partial affine binding under "
571+
<< high_exclusive_loop_->loop_var;
572+
} else {
573+
ss << "The block {0} is required to have an affine binding";
574+
}
575+
return ss.str();
557576
}
558577
IRModule mod() const final { return mod_; }
559578
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
560579
IRModule mod_;
561580
Block block_;
581+
const ForNode* high_exclusive_loop_{nullptr};
562582
};
563583

564-
if (!self->IsAffineBlockBinding(self->stmt2ref.at(block.get()))) {
565-
throw NotAffineBindingError(self->mod, std::move(block));
584+
StmtSRef block_sref = self->stmt2ref.at(block.get());
585+
if (self->IsAffineBlockBinding(block_sref)) {
586+
// check block cached state for global affineness
587+
return;
588+
}
589+
if (block_sref->parent && high_exclusive.defined()) {
590+
// if the binding is not globally affine, check affineness under high_exclusive
591+
arith::Analyzer analyzer;
592+
Map<Var, Range> dom_map =
593+
LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive);
594+
if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) {
595+
return;
596+
}
566597
}
598+
throw NotAffineBindingError(self->mod, std::move(block), high_exclusive);
599+
}
600+
601+
void CheckAffineBinding(const ScheduleState& self, Block block) {
602+
CheckPartialAffineBinding(self, std::move(block), NullOpt);
567603
}
568604

569605
Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive,

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,30 +134,35 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
134134
class BlockPropertyError : public ScheduleError {
135135
public:
136136
/*!
137-
* \brief Check that all the blocks under the specific stmt have affine bindings and only have
138-
* data-parallel or reduction block iters
137+
* \brief Check that all the blocks under the specific stmt have affine bindings
138+
* with respect to the top loop sref, and only have data-parallel or reduction block iters
139139
* \param self The state of the schedule
140140
* \param sref The sref to the specific stmt
141141
*/
142-
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
142+
static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
143143
const StmtSRefNode* sref) {
144144
class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
145145
public:
146-
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
146+
explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
147+
const StmtSRefNode* top)
148+
: state_(state), top_(top) {}
147149

148150
private:
149151
void VisitStmt_(const BlockNode* op) final {
150152
for (const IterVar& iter_var : op->iter_vars) {
151153
if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
152154
throw BlockPropertyError(state_->mod, GetRef<Block>(op));
153155
}
154-
CheckAffineBinding(state_, GetRef<Block>(op));
156+
Optional<StmtSRef> high_exclusive =
157+
top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
158+
CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
155159
}
156160
}
157161
const ScheduleState& state_;
162+
const StmtSRefNode* top_;
158163
};
159164

160-
BlockIterTypeAndAffineBindingChecker checker(self);
165+
BlockIterTypeAndAffineBindingChecker checker(self, top);
161166
checker(GetRef<Stmt>(sref->stmt));
162167
}
163168

@@ -708,8 +713,8 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
708713
// Step 3. Collect all loops in the chain and check the loops are single-branch
709714
std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
710715
// Step 4. Check the block below has all its block_var to be data-parallel or reduction,
711-
// and the block has an affine binding.
712-
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
716+
// and the block has an affine binding with respect to the top of the loop range.
717+
BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
713718
// Step 5. Replace the original loops with the reordered loops and check that outer loop is
714719
// not dependent on inner loop
715720
For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);

tests/python/unittest/test_tir_schedule_reorder.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,95 @@ def test_reorder_with_opaque_access():
213213
verify_trace_roundtrip(sch=sch, mod=opaque_access)
214214

215215

216+
def test_reorder_with_partial_affineness():
217+
@T.prim_func
218+
def non_affine_func(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
219+
# example that writes the first axis multiple times
220+
for v0, v1, v2 in T.grid(6, 4, 4):
221+
with T.block("block"):
222+
i = T.axis.spatial(14, v0 * 2 + v1)
223+
j = T.axis.spatial(4, v2)
224+
B[i, j] = A[i, j] + 1.0
225+
226+
@T.prim_func
227+
def non_affine_func_reorder(A: T.Buffer[(14, 4), "float32"], B: T.Buffer[(14, 4), "float32"]):
228+
# example that writes the first axis multiple times
229+
for v0, v2, v1 in T.grid(6, 4, 4):
230+
with T.block("block"):
231+
i = T.axis.spatial(14, v0 * 2 + v1)
232+
j = T.axis.spatial(4, v2)
233+
B[i, j] = A[i, j] + 1.0
234+
235+
sch = tir.Schedule(non_affine_func, debug_mask="all")
236+
v0, v1, v2 = sch.get_loops(sch.get_block("block"))
237+
with pytest.raises(tvm.tir.ScheduleError):
238+
sch.reorder(v0, v2, v1)
239+
240+
sch.reorder(v2, v1)
241+
tvm.ir.assert_structural_equal(non_affine_func_reorder, sch.mod["main"])
242+
verify_trace_roundtrip(sch=sch, mod=non_affine_func)
243+
244+
245+
def test_reorder_with_cascade_tiled_ops():
246+
@T.prim_func
247+
def cascade_pool_ops(
248+
x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
249+
) -> None:
250+
y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
251+
for n, c, h, w, kh, kw in T.grid(1, 16, 110, 110, 3, 3):
252+
with T.block("pool_0"):
253+
ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
254+
with T.init():
255+
y1[ax0, ax1, ax2, ax3] = 0.0
256+
y1[ax0, ax1, ax2, ax3] = y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
257+
for n, c, h, w, kh, kw in T.grid(1, 16, 108, 108, 3, 3):
258+
with T.block("pool_1"):
259+
ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [n, c, h, w, kh, kw])
260+
with T.init():
261+
y2[ax0, ax1, ax2, ax3] = 0.0
262+
y2[ax0, ax1, ax2, ax3] = y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
263+
264+
@T.prim_func
265+
def cascade_pool_ops_tile_reordered(
266+
x: T.Buffer[(1, 16, 112, 112), "float32"], y2: T.Buffer[(1, 16, 108, 108), "float32"]
267+
) -> None:
268+
y1 = T.alloc_buffer([1, 16, 110, 110], dtype="float32")
269+
for n, c, h_o in T.grid(1, 16, 27):
270+
for w, h_i, kh, kw in T.grid(110, 6, 3, 3):
271+
with T.block("pool_0"):
272+
ax0 = T.axis.spatial(1, 0)
273+
ax1 = T.axis.spatial(16, c)
274+
ax2 = T.axis.spatial(110, h_o * 4 + h_i)
275+
ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
276+
with T.init():
277+
y1[ax0, ax1, ax2, ax3] = 0.0
278+
y1[ax0, ax1, ax2, ax3] = (
279+
y1[ax0, ax1, ax2, ax3] + x[ax0, ax1, ax2 + rv0, ax3 + rv1]
280+
)
281+
for h_i, w, kh, kw in T.grid(4, 108, 3, 3):
282+
with T.block("pool_1"):
283+
ax0 = T.axis.spatial(1, 0)
284+
ax1 = T.axis.spatial(16, c)
285+
ax2 = T.axis.spatial(108, h_o * 4 + h_i)
286+
ax3, rv0, rv1 = T.axis.remap("SRR", [w, kh, kw])
287+
with T.init():
288+
y2[ax0, ax1, ax2, ax3] = 0.0
289+
y2[ax0, ax1, ax2, ax3] = (
290+
y2[ax0, ax1, ax2, ax3] + y1[ax0, ax1, ax2 + rv0, ax3 + rv1]
291+
)
292+
293+
sch = tvm.tir.schedule.Schedule(cascade_pool_ops)
294+
pool_0 = sch.get_block("pool_0")
295+
pool_1 = sch.get_block("pool_1")
296+
_, _, h, w, _, _ = sch.get_loops(pool_1)
297+
ho, _ = sch.split(h, factors=[None, 4])
298+
sch.compute_at(pool_0, ho)
299+
_, _, _, h_i, w, _, _ = sch.get_loops(pool_0)
300+
sch.reorder(w, h_i)
301+
tvm.ir.assert_structural_equal(cascade_pool_ops_tile_reordered, sch.mod["main"], True)
302+
verify_trace_roundtrip(sch=sch, mod=cascade_pool_ops)
303+
304+
216305
def test_reorder_with_predicate():
217306
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
218307
block_b = sch.get_block("B")

0 commit comments

Comments
 (0)