Skip to content

Commit 8e7382c

Browse files
[TIR][Schedule] enhance compute_at primitive to choose proper position
1 parent c477c76 commit 8e7382c

File tree

9 files changed

+173
-28
lines changed

9 files changed

+173
-28
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,8 @@ class ScheduleNode : public runtime::Object {
431431
* \param loop_rv The loop under which the block is to be moved
432432
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
433433
*/
434-
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
435-
bool preserve_unit_loops) = 0;
434+
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
435+
bool to_early_stage = false) = 0;
436436
/*!
437437
* \brief Move a consumer block under the specific loop, and regenerate the
438438
* loops induced by the block so that the buffer region consumed by the consumer block could

python/tvm/tir/schedule/schedule.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,7 @@ def compute_at(
12611261
block: Union[BlockRV, str],
12621262
loop: LoopRV,
12631263
preserve_unit_loops: bool = False,
1264+
to_early_stage: bool = False,
12641265
) -> None:
12651266
"""Compute-At. Move a producer block under the specific loop, and regenerate the
12661267
loops induced by the block so that the buffer region produced by the producer block could
@@ -1290,6 +1291,9 @@ def compute_at(
12901291
preserve_unit_loops: bool
12911292
Whether to keep the trivial loops whose extents are 1
12921293
1294+
to_early_stage: bool
1295+
Whether to place the block earlier, ahead of the preceding producers of its consumer, instead of directly before its first consumer
1296+
12931297
Examples
12941298
--------
12951299
@@ -1347,6 +1351,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
13471351
block,
13481352
loop,
13491353
preserve_unit_loops,
1354+
to_early_stage,
13501355
)
13511356

13521357
@type_checked

src/tir/schedule/concrete_schedule.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
567567
/******** Schedule: Compute location ********/
568568

569569
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
570-
bool preserve_unit_loops) {
570+
bool preserve_unit_loops, bool to_early_stage) {
571571
static StmtSRef inline_mark = StmtSRef::InlineMark();
572572
static StmtSRef root_mark = StmtSRef::RootMark();
573573
StmtSRef loop_sref = this->GetSRef(loop_rv);
@@ -579,7 +579,7 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
579579
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
580580
} else {
581581
TVM_TIR_SCHEDULE_BEGIN();
582-
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
582+
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, to_early_stage);
583583
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
584584
}
585585
this->state_->DebugVerify();

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ class ConcreteScheduleNode : public ScheduleNode {
119119
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
120120
BufferIndexType buffer_index_type) override;
121121
/******** Schedule: Compute location ********/
122-
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
122+
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
123+
bool to_early_stage = false) override;
123124
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
124125
bool preserve_unit_loops) override;
125126
void ComputeInline(const BlockRV& block) override;

src/tir/schedule/primitive.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
301301
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
302302
*/
303303
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
304-
bool preserve_unit_loops);
304+
bool preserve_unit_loops, bool to_early_stage = false);
305305
/*!
306306
* \brief Move a consumer block under the specific loop, and regenerate the
307307
* loops induced by the block so that the buffer region consumed by the consumer block could

src/tir/schedule/primitive/compute_at.cc

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,24 +120,38 @@ class NotInSameScopeError : public ScheduleError {
120120

121121
/******** Helper Functions/Classes ********/
122122

123+
/*!
 * \brief Find a block reachable from the given statement.
 * \param stmt The statement to be searched
 * \return The block of a BlockRealize reached by the visitor. If several BlockRealize nodes are
 * reached, the one reached last overwrites the earlier ones. If none is reached, the returned
 * Stmt is left default-constructed (callers test it with `defined()`).
 * \note Traversal order is whatever StmtVisitor provides - presumably pre-order; TODO confirm.
 */
Stmt GetBlock(Stmt stmt) {
124+
class Finder : public StmtVisitor {
125+
public:
126+
// Record the block and stop here: the base visitor is not invoked, so nothing nested inside
// this BlockRealize is visited.
void VisitStmt_(const BlockRealizeNode* realize) final { blk = realize->block; }
127+
// The block recorded by the most recent BlockRealize visit.
Stmt blk;
128+
};
129+
Finder finder;
130+
finder(stmt);
131+
return finder.blk;
132+
}
133+
123134
/*!
124135
* \brief Find a point where the block can be inserted under the loop
125136
* \tparam require_all_producers_visited Requires all producer blocks to be present under the loop
126137
* \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop
127138
* \param self The schedule state
139+
* \param scope The BlockScope of the scope root block, used to look up the producers of blocks under the loop
128140
* \param subtrees The subtrees under the loop, among which the insertion points are sought
129141
* \param producer_srefs The producer blocks
130142
* \param consumer_srefs The consumer blocks
131143
* \param block2realize A cache that maps a block to its realize
144+
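* \param to_early_stage Whether to move the insertion point earlier, ahead of preceding subtrees that contain producers of the subtree following them, instead of directly before the first consumer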
132145
* \return The last position (or an earlier one if to_early_stage is set) the new block can be inserted onto, and the
133146
* producer-consumer-relationship is still satisfied.
134147
* \throws ScheduleError if there is no such insertion point found
135148
*/
136149
template <bool require_all_producers_visited, bool require_all_consumers_visited>
137-
int FindInsertionPoint(
138-
const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
139-
const Array<StmtSRef>& consumer_srefs,
140-
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
150+
int FindInsertionPoint(const ScheduleState& self, const BlockScope scope,
151+
const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
152+
const Array<StmtSRef>& consumer_srefs,
153+
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
154+
bool to_early_stage) {
141155
ProducerConsumerSplit split =
142156
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
143157
// Step 1. Check if all the producers are visited in the subtrees, if required to
@@ -160,7 +174,37 @@ int FindInsertionPoint(
160174
// The valid indices are: (last_producer_position, first_consumer_position]
161175
ICHECK(split.last_producer_position < split.first_consumer_position);
162176
// Step 4. Start from the last valid insertion point; it may be moved earlier before returning
163-
return split.first_consumer_position;
177+
int insert_position = split.first_consumer_position;
178+
if (require_all_consumers_visited && to_early_stage) {
179+
class Finder : public StmtVisitor {
180+
public:
181+
void VisitStmt_(const BlockRealizeNode* realize) final {
182+
const BlockNode* block = realize->block.get();
183+
if (producer_blocks_.count(block)) {
184+
++this->n_producers_visited_;
185+
}
186+
}
187+
188+
std::unordered_set<const StmtNode*> producer_blocks_;
189+
int n_producers_visited_ = 0;
190+
};
191+
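// Walk backwards from the first consumer: while the preceding subtree contains a producer of
// the block found in the current subtree, move the insertion point before that subtree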
192+
for (int i = split.first_consumer_position; i - 1 > split.last_producer_position; --i) {
193+
auto blk = GetBlock(subtrees[i]);
194+
if (!blk.defined()) break;
195+
auto block_sref = self->stmt2ref.at(blk.get());
196+
Array<StmtSRef> block_producer_srefs = GetProducers(block_sref, scope);
197+
Finder finder;
198+
finder.producer_blocks_.reserve(block_producer_srefs.size());
199+
for (const StmtSRef& block_sref_ : block_producer_srefs) {
200+
finder.producer_blocks_.insert(block_sref_->stmt);
201+
}
202+
finder(subtrees[i - 1]);
203+
if (finder.n_producers_visited_ == 0) break;
204+
insert_position = i - 1;
205+
}
206+
}
207+
return insert_position;
164208
}
165209

166210
/*!
@@ -556,7 +600,8 @@ void CalculateProvidedRequiredRegions(
556600
template <bool is_compute_at>
557601
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
558602
const StmtSRef& loop_sref, bool preserve_unit_loops,
559-
arith::Analyzer* analyzer, bool check_only = false) {
603+
arith::Analyzer* analyzer, bool check_only = false,
604+
bool to_early_stage = false) {
560605
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
561606
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
562607
// Step 1. Bunch of checks
@@ -585,10 +630,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
585630
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
586631
block2realize.reserve(self->block_info.size());
587632
int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
588-
/*self=*/self,
633+
/*self=*/self, /*scope=*/scope,
589634
/*subtrees=*/AsArray(loop->body),
590635
/*producer_srefs=*/producer_srefs,
591-
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
636+
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
637+
/*to_early_stage*/ to_early_stage);
592638
// Step 4. Calculate the region provided by a single execution instance of `block`,
593639
// as well as the region required by dependent blocks under `loop`.
594640
// Here is the definition of `provide` and `require`:
@@ -626,10 +672,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
626672
}
627673

628674
void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
629-
bool preserve_unit_loops) {
675+
bool preserve_unit_loops, bool to_early_stage) {
630676
arith::Analyzer analyzer;
631-
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
632-
&analyzer);
677+
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
678+
false, to_early_stage);
633679
}
634680

635681
void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
@@ -671,20 +717,22 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {
671717

672718
private:
673719
static constexpr size_t kNumInputs = 2;
674-
static constexpr size_t kNumAttrs = 1;
720+
static constexpr size_t kNumAttrs = 2;
675721
static constexpr size_t kNumDecisions = 0;
676722

677723
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
678-
Bool preserve_unit_loops) {
679-
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
724+
Bool preserve_unit_loops, Bool to_early_stage) {
725+
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
726+
to_early_stage.operator bool());
680727
}
681728

682729
static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
683-
Bool preserve_unit_loops) {
730+
Bool preserve_unit_loops, Bool to_early_stage) {
684731
PythonAPICall py("compute_at");
685732
py.Input("block", block_rv);
686733
py.Input("loop", loop_rv);
687734
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
735+
py.Input("to_early_stage", to_early_stage.operator bool());
688736
return py.Str();
689737
}
690738

src/tir/schedule/traced_schedule.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,15 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
320320
/******** Schedule: Compute location ********/
321321

322322
void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
323-
bool preserve_unit_loops) {
324-
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
323+
bool preserve_unit_loops, bool to_early_stage) {
324+
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, to_early_stage);
325325

326326
static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
327-
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
328-
/*inputs=*/{block_rv, loop_rv},
329-
/*attrs=*/{Integer(preserve_unit_loops)},
330-
/*outputs=*/{}));
327+
trace_->Append(
328+
/*inst=*/Instruction(/*kind=*/kind,
329+
/*inputs=*/{block_rv, loop_rv},
330+
/*attrs=*/{Integer(preserve_unit_loops), Integer(to_early_stage)},
331+
/*outputs=*/{}));
331332
}
332333

333334
void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,

src/tir/schedule/traced_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
7979
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
8080
BufferIndexType buffer_index_type) final;
8181
/******** Schedule: Compute location ********/
82-
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
82+
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
83+
bool to_early_stage = false) final;
8384
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
8485
bool preserve_unit_loops) final;
8586
void ComputeInline(const BlockRV& block_rv) final;

tests/python/unittest/test_tir_schedule_compute_at.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,5 +1353,94 @@ def _create_prim_func():
13531353
verify_trace_roundtrip(sch=sch, mod=mod)
13541354

13551355

1356+
def test_compute_at_to_early_stage():
1357+
@T.prim_func
1358+
def multi_producers_conv(
1359+
data: T.Buffer[(1, 3, 224, 224), "int8"],
1360+
w: T.Buffer[(16, 3, 7, 7), "int8"],
1361+
conv: T.Buffer[(1, 16, 112, 112), "int32"],
1362+
) -> None:
1363+
pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
1364+
wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
1365+
for i0, i1, i2, i3 in T.grid(1, 3, 230, 230):
1366+
with T.block("pad"):
1367+
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
1368+
T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
1369+
T.writes(pad[i0_1, i1_1, i2_1, i3_1])
1370+
pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
1371+
3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
1372+
data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
1373+
T.int8(0),
1374+
dtype="int8",
1375+
)
1376+
for i0 in T.serial(1):
1377+
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
1378+
with T.block("wbuf"):
1379+
v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1380+
T.reads(w[v0, v1, v2, v3])
1381+
T.writes(wbuf[v0, v1, v2, v3])
1382+
wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
1383+
for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
1384+
with T.block("conv"):
1385+
nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
1386+
"SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
1387+
)
1388+
T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
1389+
T.writes(conv[nn, ff, yy, xx])
1390+
with T.init():
1391+
conv[nn, ff, yy, xx] = 0
1392+
conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
1393+
pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
1394+
) * T.cast(wbuf[ff, rc, ry, rx], "int32")
1395+
1396+
@T.prim_func
1397+
def multi_producers_after_compute_at(
1398+
data: T.Buffer[(1, 3, 224, 224), "int8"],
1399+
w: T.Buffer[(16, 3, 7, 7), "int8"],
1400+
conv: T.Buffer[(1, 16, 112, 112), "int32"],
1401+
) -> None:
1402+
pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
1403+
wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
1404+
for i0 in T.serial(1):
1405+
for ax0, ax1, ax2 in T.grid(3, 229, 229):
1406+
with T.block("pad"):
1407+
i0_1 = T.axis.spatial(1, 0)
1408+
i1_1 = T.axis.spatial(3, ax0)
1409+
i2_1 = T.axis.spatial(230, ax1)
1410+
i3_1 = T.axis.spatial(230, ax2)
1411+
T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
1412+
T.writes(pad[i0_1, i1_1, i2_1, i3_1])
1413+
pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
1414+
3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
1415+
data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
1416+
T.int8(0),
1417+
dtype="int8",
1418+
)
1419+
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
1420+
with T.block("wbuf"):
1421+
v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1422+
T.reads(w[v0, v1, v2, v3])
1423+
T.writes(wbuf[v0, v1, v2, v3])
1424+
wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
1425+
for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
1426+
with T.block("conv"):
1427+
nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
1428+
"SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
1429+
)
1430+
T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
1431+
T.writes(conv[nn, ff, yy, xx])
1432+
with T.init():
1433+
conv[nn, ff, yy, xx] = 0
1434+
conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
1435+
pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
1436+
) * T.cast(wbuf[ff, rc, ry, rx], "int32")
1437+
1438+
sch = tir.Schedule(multi_producers_conv, debug_mask="all")
1439+
block_c = sch.get_block("pad")
1440+
axis = sch.get_loops("conv")[0]
1441+
sch.compute_at(block_c, axis, to_early_stage=True)
1442+
tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"])
1443+
1444+
13561445
if __name__ == "__main__":
13571446
tvm.testing.main()

0 commit comments

Comments
 (0)