Skip to content

Commit cdbb49d

Browse files
[TIR] Add merge primitive for TIR schedule
1 parent 8dea77a commit cdbb49d

File tree

10 files changed

+587
-0
lines changed

10 files changed

+587
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
292292
*/
293293
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
294294
/******** Schedule: Transform loops ********/
295+
/*!
296+
* \brief Merge a list of loops into one. The loops under their LCA must satisfy:
297+
* 1) Under the same scope
298+
* 2) Can't have annotations or thread bindings
299+
* 3) Start with 0 and have same extent and same nesting depth
300+
* 4) From target loop to their LCA, the inner loop must be the only child of the outer loop
301+
* \param loop_rvs The loops to be merged
302+
* \return The new loop after merge
303+
*/
304+
virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
295305
/*!
296306
* \brief Fuse a list of consecutive loops into one. It requires:
297307
* 1) The loops can't have annotations or thread bindings.

python/tvm/tir/schedule/schedule.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,84 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
541541
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member
542542

543543
########## Schedule: Transform loops ##########
544+
@type_checked
def merge(
    self,
    *loops: List[LoopRV],
) -> LoopRV:
    """Merge a list of loops into one. The loops under their LCA must satisfy:
    1) They are under the same scope.
    2) They can't have annotations or thread bindings.
    3) They start with 0 and have the same extent and the same nesting depth.
    4) From each target loop to their LCA, the inner loop must be the only child of
       the outer loop.

    Parameters
    ----------
    *loops : List[LoopRV]
        The loops to be merged

    Returns
    -------
    merged_loop : LoopRV
        The new loop after merge

    Examples
    --------

    Before applying merge, in TensorIR, the IR is:

    .. code-block:: python

        @T.prim_func
        def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
            A = T.match_buffer(a, (128, 128))
            B = T.match_buffer(b, (128, 128))
            C = T.match_buffer(c, (128, 128))
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0
            for i, j in T.grid(128, 128):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do merge:

    .. code-block:: python

        sch = tir.Schedule(before_merge)
        i1, _ = sch.get_loops(sch.get_block("B"))
        i2, _ = sch.get_loops(sch.get_block("C"))
        sch.merge(i1, i2)
        print(sch.mod["main"].script())

    After applying merge, the IR becomes:

    .. code-block:: python

        @T.prim_func
        def after_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
            A = T.match_buffer(a, (128, 128))
            B = T.match_buffer(b, (128, 128))
            C = T.match_buffer(c, (128, 128))
            # the 2 loops are merged into 1
            for i_m in range(128):
                for j in range(128):
                    with T.block("B"):
                        vi, vj = T.axis.remap("SS", [i_m, j])
                        T.reads(A[vi, vj])
                        T.writes(B[vi, vj])
                        B[vi, vj] = A[vi, vj] * T.float32(2)
                for j in range(128):
                    with T.block("C"):
                        vi, vj = T.axis.remap("SS", [i_m, j])
                        T.reads(A[vi, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = A[vi, vj] * T.float32(2)
    """
    return _ffi_api.ScheduleMerge(self, loops)  # type: ignore # pylint: disable=no-member
621+
544622
@type_checked
545623
def fuse(
546624
self,

src/tir/schedule/concrete_schedule.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
356356

357357
/******** Schedule: Transform loops ********/
358358

359+
LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
  // Merging needs at least two loops to combine.
  CHECK(loop_rvs.size() > 1) << "ValueError: 'merge' requires at least 2 loop(s)";
  // Resolve the random variables to srefs before entering the guarded region.
  Array<StmtSRef> target_srefs = this->GetSRefs(loop_rvs);
  StmtSRef merged_sref{nullptr};
  // The primitive itself runs between the schedule BEGIN/END macros, which are given the
  // primitive name and the configured error render level.
  TVM_TIR_SCHEDULE_BEGIN();
  merged_sref = tir::Merge(state_, target_srefs);
  TVM_TIR_SCHEDULE_END("merge", this->error_render_level_);
  this->state_->DebugVerify();
  // Hand the merged loop back to the caller as a fresh random variable.
  return CreateRV<LoopRV>(merged_sref);
}
369+
359370
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
360371
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
361372
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class ConcreteScheduleNode : public ScheduleNode {
101101
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
102102
/******** Schedule: Transform loops ********/
103103
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
104+
LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
104105
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
105106
bool preserve_unit_iters) override;
106107
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;

src/tir/schedule/primitive.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
161161
*/
162162
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
163163
const Array<PrimExpr>& factors, bool preserve_unit_iters);
164+
165+
/*!
166+
* \brief Merge a list of loops into one. The loops under their LCA must satisfy:
167+
* 1) Under the same scope
168+
* 2) Can't have annotations or thread bindings
169+
* 3) Start with 0 and have same extent and same nesting depth
170+
* 4) From target loop to their LCA, the inner loop must be the only child of the outer loop
171+
* \param self The state of the schedule
172+
* \param loop_srefs An array of srefs to the loops to be merged
173+
* \return The new loop after merge
174+
*/
175+
TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
176+
164177
/*!
165178
* \brief Fuse a list of consecutive loops into one. It requires:
166179
* 1) The loops can't have annotations or thread bindings.

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,162 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451
return result_srefs;
452452
}
453453

454+
class LoopReconstructor : private StmtMutator {
455+
public:
456+
explicit LoopReconstructor(Block scope_root, const std::vector<std::vector<For>>& loops)
457+
: scope_root_(scope_root), loops_(loops) {}
458+
459+
using StmtMutator::operator();
460+
461+
/*!
462+
* \brief Create the new nest loops induced by the given loops
463+
*/
464+
void MakeNewLoop() {
465+
Array<Var> new_loop_vars;
466+
Array<PrimExpr> new_loop_extents;
467+
Array<Stmt> new_stmts;
468+
for (size_t i = 0; i < loops_.size(); i++) {
469+
Map<Var, PrimExpr> var_map;
470+
for (size_t j = 0; j < loops_[i].size(); j++) {
471+
if (i == 0) {
472+
Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
473+
new_loop_vars.push_back(merged_var);
474+
new_loop_extents.push_back(loops_[i][j]->extent);
475+
}
476+
var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
477+
}
478+
auto new_stmt = Substitute(loops_[i][0]->body, var_map);
479+
new_stmts.push_back(new_stmt);
480+
this->need_remove_loop_.push_back(loops_[i].back());
481+
}
482+
auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial,
483+
SeqStmt(std::move(new_stmts)));
484+
this->new_inner_loop_ = new_loop;
485+
for (size_t i = 1; i < new_loop_vars.size(); ++i) {
486+
const Var& loop_var = new_loop_vars[i];
487+
const PrimExpr& loop_extent = new_loop_extents[i];
488+
new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop);
489+
}
490+
this->new_outer_loop_ = new_loop;
491+
}
492+
493+
private:
494+
Stmt VisitStmt_(const BlockNode* block) final {
495+
if (block != scope_root_.get()) {
496+
return GetRef<Block>(block);
497+
}
498+
return StmtMutator::VisitStmt_(block);
499+
}
500+
501+
Stmt VisitStmt_(const ForNode* loop) final {
502+
if (GetRef<For>(loop) == need_remove_loop_.back()) {
503+
return new_outer_loop_;
504+
} else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), GetRef<For>(loop))) {
505+
return Evaluate(0);
506+
}
507+
return StmtMutator::VisitStmt_(loop);
508+
}
509+
510+
Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
511+
auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
512+
Array<Stmt> filtered;
513+
for (Stmt stmt : ret->seq) {
514+
if (!is_no_op(stmt)) {
515+
filtered.push_back(std::move(stmt));
516+
}
517+
}
518+
ret = SeqStmt(filtered);
519+
if (ret->size() == 0) {
520+
return Evaluate(0);
521+
} else if (ret->size() == 1) {
522+
return ret->seq[0];
523+
} else {
524+
return std::move(ret);
525+
}
526+
}
527+
528+
public:
529+
/*! \brief The root block of the block scope */
530+
Block scope_root_;
531+
/*! \brief The given loops to be merge */
532+
const std::vector<std::vector<For>>& loops_;
533+
/*! \brief The outermost new loop to replace the original loop */
534+
For new_outer_loop_{nullptr};
535+
/*! \brief The innermost new loop to replace the original loop */
536+
For new_inner_loop_{nullptr};
537+
/*! \brief The loops to be removed */
538+
std::vector<For> need_remove_loop_;
539+
};
540+
541+
/*!
 * Merge the given loops into one loop nest placed under their lowest common ancestor (LCA).
 * For every given loop, the chain of loops from it up to (excluding) the LCA is collected
 * and validated: no annotations or thread bindings, start at 0, each inner loop is the whole
 * body of its outer loop, and all chains agree on scope root, depth and extents.
 * Returns the sref of the innermost loop of the new nest.
 */
StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
  // Invariance
  // - The total repeat number has not changed for each direct child block.
  // - The execution order has not changed. (The block executes with the same
  //   args and the same order with before.)
  arith::Analyzer analyzer;
  StmtSRef scope_root_sref;
  StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
  // One entry per given loop, each ordered from the given loop (index 0) outwards.
  std::vector<std::vector<For>> lca_nest_loops;
  lca_nest_loops.reserve(loop_srefs.size());
  // Step 1. check correctness
  for (size_t i = 0; i < loop_srefs.size(); i++) {
    const StmtSRef& sref = loop_srefs[i];
    StmtSRef cur_scope_root_sref = GetScopeRoot(self, sref, /*require_stage_pipeline=*/false);
    std::vector<For> nest_loops;
    for (auto p = sref.get(); p != lca.get(); p = p->parent) {
      if (auto loop = p->StmtAs<ForNode>()) {
        if (!loop->annotations.empty() || loop->thread_binding.defined()) {
          throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
        }
        CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
        nest_loops.push_back(GetRef<For>(loop));
      }
    }
    // The walk above collects nothing when the given loop is the LCA itself, e.g. when the
    // same loop is passed more than once. Reject it here instead of indexing an empty nest
    // during reconstruction.
    if (nest_loops.empty()) {
      LOG(FATAL) << "ScheduleError: Expected the loops to be merged to be distinct and not to "
                 << "contain each other, but loop #" << i
                 << " is the lowest common ancestor of the given loops.";
    }
    // Walk from the outermost loop inwards: each loop must be the entire body of its parent.
    for (size_t j = nest_loops.size() - 1; j > 0; --j) {
      const For& outer_loop = nest_loops[j];
      const For& inner_loop = nest_loops[j - 1];
      if (!outer_loop->body.same_as(inner_loop)) {
        throw NotOnlyChildError(self->mod, outer_loop, inner_loop);
      }
    }
    if (i == 0) {
      scope_root_sref = cur_scope_root_sref;
    } else {
      // Every later nest is compared against the first one.
      const std::vector<For>& reference_nest = lca_nest_loops.front();
      if (cur_scope_root_sref.get() != scope_root_sref.get()) {
        LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope.";
      }
      if (nest_loops.size() != reference_nest.size()) {
        LOG(FATAL) << "ScheduleError: Merge loop's nesting depth must be same, but not.";
      }
      for (size_t j = 0; j < nest_loops.size(); j++) {
        if (!analyzer.CanProveEqual(nest_loops[j]->extent, reference_nest[j]->extent)) {
          LOG(FATAL) << "ScheduleError: Merge loop's `extent` must be same, but not."
                     << " extent=[" << j << "," << reference_nest[j]->extent << ","
                     << nest_loops[j]->extent << "]";
        }
      }
    }
    lca_nest_loops.push_back(std::move(nest_loops));
  }
  // Step 2. Create merged loops and replace the original loops
  Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  LoopReconstructor reconstructor(scope_root, lca_nest_loops);
  reconstructor.MakeNewLoop();
  Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
  // Step 3. Do the actual replacement
  self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
  return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
}
609+
454610
StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455611
// Invariance
456612
// - The total repeat number has not changed for each direct child block.
@@ -795,6 +951,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795951
friend struct ::tvm::tir::UnpackedInstTraits;
796952
};
797953

954+
/*! \brief Instruction traits of the "Merge" schedule primitive. */
struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
  static constexpr const char* kName = "Merge";
  static constexpr bool kIsPure = false;

 private:
  // All loops to be merged are packed into a single array-valued input.
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Pass the whole input array as one argument at position `delta`. */
  template <size_t delta>
  static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
                                           const Array<ObjectRef>& inputs) {
    setter(delta, inputs);
  }

  /*! \brief Apply the instruction by calling Merge on the schedule. */
  static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
    return sch->Merge(loop_rvs);
  }

  /*! \brief Render the instruction as a Python `merge` call with one unnamed input per loop. */
  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
    PythonAPICall python_call("merge");
    for (const String& loop_name : loop_rvs) {
      python_call.Input("", loop_name);
    }
    python_call.SingleOutput(outputs);
    return python_call.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
985+
798986
struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
799987
static constexpr const char* kName = "Fuse";
800988
static constexpr bool kIsPure = false;
@@ -893,6 +1081,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931081
};
8941082

8951083
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
1084+
TVM_REGISTER_INST_KIND_TRAITS(MergeTraits);
8961085
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
8971086
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
8981087
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);

src/tir/schedule/schedule.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
153153
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
154154
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
155155
/******** (FFI) Transform loops ********/
156+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
156157
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
157158
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
158159
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")

src/tir/schedule/traced_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
176176

177177
/******** Schedule: Transform loops ********/
178178

179+
LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
  // Run the primitive first; the instruction is appended to the trace only after the
  // concrete call has returned.
  LoopRV merged_loop = ConcreteScheduleNode::Merge(loop_rvs);

  static const InstructionKind& merge_kind = InstructionKind::Get("Merge");
  Array<ObjectRef> inst_inputs(loop_rvs.begin(), loop_rvs.end());
  Array<ObjectRef> inst_outputs{merged_loop};
  Instruction inst(/*kind=*/merge_kind,
                   /*inputs=*/inst_inputs,
                   /*attrs=*/{},
                   /*outputs=*/inst_outputs);
  trace_->Append(/*inst=*/inst);
  return merged_loop;
}
188+
179189
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
180190
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);
181191

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
6161
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
6262
/******** Schedule: Transform loops ********/
6363
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
64+
LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
6465
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
6566
bool preserve_unit_iters) final;
6667
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;

0 commit comments

Comments
 (0)