Skip to content

Commit 1b9660f

Browse files
[TIR] Add merge primitive for TIR schedule
1 parent b3a5e18 commit 1b9660f

File tree

10 files changed

+512
-0
lines changed

10 files changed

+512
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
292292
*/
293293
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
294294
/******** Schedule: Transform loops ********/
295+
/*!
296+
* \brief Merge a list of loops into one. The loops under their LCA requires:
297+
* 1) Under the same scope.
298+
* 2) Can't have annotations or thread bindings
299+
* 3) Start with 0 and have same domain.
300+
* 4) The inner loop must be the only child of the outer loop.
301+
* \param loop_rvs The loops to the loops to be merged
302+
* \return The new loop after merge
303+
*/
304+
virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
295305
/*!
296306
* \brief Fuse a list of consecutive loops into one. It requires:
297307
* 1) The loops can't have annotations or thread bindings.

python/tvm/tir/schedule/schedule.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,84 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
541541
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member
542542

543543
########## Schedule: Transform loops ##########
544+
@type_checked
545+
def merge(
546+
self,
547+
*loops: List[LoopRV],
548+
) -> LoopRV:
549+
"""Merge a list of loops into one. The loops under their LCA requires:
550+
1) Under the same scope
551+
2) Can't have annotations or thread bindings.
552+
3) Start with 0 and have same domain
553+
4) The inner loop must be the only child of the outer loop.
554+
555+
Parameters
556+
----------
557+
*loops : List[LoopRV]
558+
The loops to be merged
559+
560+
Returns
561+
-------
562+
fused_loop : LoopRV
563+
The new loop after merge
564+
565+
Examples
566+
--------
567+
568+
Before applying merge, in TensorIR, the IR is:
569+
570+
.. code-block:: python
571+
572+
@T.prim_func
573+
def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
574+
A = T.match_buffer(a, (128, 128))
575+
B = T.match_buffer(b, (128, 128))
576+
C = T.match_buffer(c, (128, 128))
577+
for i, j in T.grid(128, 128):
578+
with T.block("B"):
579+
vi, vj = T.axis.remap("SS", [i, j])
580+
B[vi, vj] = A[vi, vj] * 2.0
581+
for i, j in T.grid(128, 128):
582+
with T.block("C"):
583+
vi, vj = T.axis.remap("SS", [i, j])
584+
C[vi, vj] = A[vi, vj] * 2.0
585+
586+
Create the schedule and do fuse:
587+
588+
.. code-block:: python
589+
590+
sch = tir.Schedule(before_fuse)
591+
i1, _ = sch.get_loops(sch.get_block("B"))
592+
i2, _ = sch.get_loops(sch.get_block("C"))
593+
sch.merge(i1, i2)
594+
print(sch.mod["main"].script())
595+
596+
After applying fuse, the IR becomes:
597+
598+
.. code-block:: python
599+
600+
@T.prim_func
601+
def after_fuse(a: T.handle, b: T.handle, c: T.handle) -> None:
602+
A = T.match_buffer(a, (128, 128))
603+
B = T.match_buffer(b, (128, 128))
604+
C = T.match_buffer(c, (128, 128))
605+
# the 2 loops are merged into 1
606+
for i_m in range(128):
607+
for j in range(128):
608+
with T.block("B"):
609+
vi, vj = T.axis.remap("SS", [i_m, j])
610+
T.reads(A[vi, vj])
611+
T.writes(B[vi, vj])
612+
B[vi, vj] = A[vi, vj] * T.float32(2)
613+
for j in range(128):
614+
with T.block("C"):
615+
vi, vj = T.axis.remap("SS", [i_m, j])
616+
T.reads(A[vi, vj])
617+
T.writes(C[vi, vj])
618+
C[vi, vj] = A[vi, vj] * T.float32(2)
619+
"""
620+
return _ffi_api.ScheduleMerge(self, loops) # type: ignore # pylint: disable=no-member
621+
544622
@type_checked
545623
def fuse(
546624
self,

src/tir/schedule/concrete_schedule.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
356356

357357
/******** Schedule: Transform loops ********/
358358

359+
LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
360+
CHECK(!loop_rvs.empty()) << "ValueError: 'merge' requires at least 1 loop(s)";
361+
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
362+
StmtSRef result{nullptr};
363+
TVM_TIR_SCHEDULE_BEGIN();
364+
result = tir::Merge(state_, loop_srefs);
365+
TVM_TIR_SCHEDULE_END("merge", this->error_render_level_);
366+
this->state_->DebugVerify();
367+
return CreateRV<LoopRV>(result);
368+
}
369+
359370
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
360371
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
361372
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class ConcreteScheduleNode : public ScheduleNode {
101101
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
102102
/******** Schedule: Transform loops ********/
103103
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
104+
LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
104105
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
105106
bool preserve_unit_iters) override;
106107
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;

src/tir/schedule/primitive.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
161161
*/
162162
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
163163
const Array<PrimExpr>& factors, bool preserve_unit_iters);
164+
165+
/*!
166+
* \brief Merge a list of loops into one. The loops under their LCA requires:
167+
* 1) Under the same scope.
168+
* 2) Can't have annotations or thread bindings
169+
* 3) Start with 0 and have same domain.
170+
* 4) The inner loop must be the only child of the outer loop.
171+
* \param self The state of the schedule
172+
* \param loop_srefs An array of srefs to the loops to be merged
173+
* \return The new loop after merge
174+
*/
175+
TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
176+
164177
/*!
165178
* \brief Fuse a list of consecutive loops into one. It requires:
166179
* 1) The loops can't have annotations or thread bindings.

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451
return result_srefs;
452452
}
453453

454+
class LoopReconstructor : private StmtMutator {
455+
public:
456+
explicit LoopReconstructor(Block scope_root,
457+
const std::vector<std::vector<const ForNode*>>& loops)
458+
: scope_root_(scope_root), loops_(loops) {}
459+
460+
using StmtMutator::operator();
461+
462+
/*!
463+
* \brief Create the new nest loops induced by the given loops
464+
*/
465+
void MakeNewLoop() {
466+
Array<Var> new_loop_vars;
467+
Array<PrimExpr> new_loop_extents;
468+
Array<Stmt> new_stmts;
469+
for (size_t i = 0; i < loops_.size(); i++) {
470+
Map<Var, PrimExpr> var_map;
471+
for (size_t j = 0; j < loops_[i].size(); j++) {
472+
if (i == 0) {
473+
Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
474+
new_loop_vars.push_back(merged_var);
475+
new_loop_extents.push_back(loops_[i][j]->extent);
476+
}
477+
var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
478+
}
479+
auto new_stmt = Substitute(loops_[i][0]->body, var_map);
480+
new_stmts.push_back(new_stmt);
481+
this->need_remove_loop_.push_back(loops_[i].back());
482+
}
483+
auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial,
484+
SeqStmt(std::move(new_stmts)));
485+
this->new_inner_loop_ = new_loop;
486+
for (size_t i = 1; i < new_loop_vars.size(); ++i) {
487+
const Var& loop_var = new_loop_vars[i];
488+
const PrimExpr& loop_extent = new_loop_extents[i];
489+
new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop);
490+
}
491+
this->new_outer_loop_ = new_loop;
492+
}
493+
494+
private:
495+
Stmt VisitStmt_(const BlockNode* block) final {
496+
if (block != scope_root_.get()) {
497+
return GetRef<Block>(block);
498+
}
499+
return StmtMutator::VisitStmt_(block);
500+
}
501+
502+
Stmt VisitStmt_(const ForNode* loop) final {
503+
if (loop == need_remove_loop_.back()) {
504+
return new_outer_loop_;
505+
} else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), loop)) {
506+
return Evaluate(0);
507+
}
508+
return StmtMutator::VisitStmt_(loop);
509+
}
510+
511+
Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
512+
auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
513+
Array<Stmt> filtered;
514+
for (Stmt stmt : ret->seq) {
515+
if (!is_no_op(stmt)) {
516+
filtered.push_back(std::move(stmt));
517+
}
518+
}
519+
ret = SeqStmt(filtered);
520+
if (ret->size() == 0) {
521+
return Evaluate(0);
522+
} else if (ret->size() == 1) {
523+
return ret->seq[0];
524+
} else {
525+
return std::move(ret);
526+
}
527+
}
528+
529+
public:
530+
/*! \brief The root block of the block scope */
531+
Block scope_root_;
532+
/*! \brief The given loops to be merge */
533+
const std::vector<std::vector<const ForNode*>>& loops_;
534+
/*! \brief The outermost new loop to replace the original loop */
535+
For new_outer_loop_{nullptr};
536+
/*! \brief The innermost new loop to replace the original loop */
537+
For new_inner_loop_{nullptr};
538+
/*! \brief The loops to be removed */
539+
std::vector<const ForNode*> need_remove_loop_;
540+
};
541+
542+
StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
543+
// Invariance
544+
// - The total repeat number has not changed for each direct child block.
545+
// - The execution order has not changed. (The block executes with the same
546+
// args and the same order with before.)
547+
arith::Analyzer analyzer;
548+
StmtSRef scope_root_sref;
549+
StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
550+
std::vector<std::vector<const ForNode*>> lca_nest_loops;
551+
// Step 1. check correctness
552+
std::vector<const ForNode*> nest_loop_loops;
553+
std::vector<PrimExpr> nest_loop_extents;
554+
for (size_t i = 0; i < loop_srefs.size(); i++) {
555+
const StmtSRef& sref = loop_srefs[i];
556+
auto scope_root_sref_ = GetScopeRoot(self, sref, /*require_stage_pipeline=*/false);
557+
std::vector<PrimExpr> nest_loop_i_extents;
558+
std::vector<const ForNode*> nest_loop_i_loops;
559+
for (auto p = sref.get(); p != lca.get(); p = p->parent) {
560+
if (auto loop = p->StmtAs<ForNode>()) {
561+
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
562+
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
563+
}
564+
CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
565+
nest_loop_i_loops.push_back(loop);
566+
nest_loop_i_extents.push_back(loop->extent);
567+
}
568+
}
569+
lca_nest_loops.push_back(nest_loop_i_loops);
570+
const ForNode* outer_loop = nullptr;
571+
for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) {
572+
if (outer_loop && !outer_loop->body.same_as(GetRef<For>(*iter))) {
573+
throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(*iter));
574+
}
575+
outer_loop = *iter;
576+
}
577+
if (i == 0) {
578+
scope_root_sref = scope_root_sref_;
579+
nest_loop_loops = nest_loop_i_loops;
580+
nest_loop_extents = nest_loop_i_extents;
581+
} else {
582+
if (scope_root_sref_.get() != scope_root_sref.get()) {
583+
LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope";
584+
throw;
585+
}
586+
if (nest_loop_i_extents.size() != nest_loop_extents.size()) {
587+
LOG(FATAL) << "ScheduleError: Merge loop's nesting depth must be same, but not";
588+
throw;
589+
} else {
590+
for (size_t j = 0; j < nest_loop_i_extents.size(); j++) {
591+
if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) {
592+
LOG(FATAL) << "ScheduleError: Merge loop's `extent` must be same, but not."
593+
<< "extent=[" << j << "," << nest_loop_extents[j] << ","
594+
<< nest_loop_i_extents[j] << "]";
595+
throw;
596+
}
597+
}
598+
}
599+
}
600+
}
601+
// Step 2. Create merged loops and replace the original loops
602+
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
603+
LoopReconstructor reconstructor(scope_root, lca_nest_loops);
604+
reconstructor.MakeNewLoop();
605+
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
606+
// Step 3. Do the actual replacement
607+
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
608+
return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
609+
}
610+
454611
StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455612
// Invariance
456613
// - The total repeat number has not changed for each direct child block.
@@ -795,6 +952,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795952
friend struct ::tvm::tir::UnpackedInstTraits;
796953
};
797954

955+
struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
956+
static constexpr const char* kName = "Merge";
957+
static constexpr bool kIsPure = false;
958+
959+
private:
960+
static constexpr size_t kNumInputs = 1;
961+
static constexpr size_t kNumAttrs = 0;
962+
static constexpr size_t kNumDecisions = 0;
963+
964+
template <size_t delta>
965+
static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
966+
const Array<ObjectRef>& inputs) {
967+
setter(delta, inputs);
968+
}
969+
970+
static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
971+
return sch->Merge(loop_rvs);
972+
}
973+
974+
static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
975+
PythonAPICall py("merge");
976+
for (const String& loop_rv : loop_rvs) {
977+
py.Input("", loop_rv);
978+
}
979+
py.SingleOutput(outputs);
980+
return py.Str();
981+
}
982+
983+
template <typename>
984+
friend struct ::tvm::tir::UnpackedInstTraits;
985+
};
986+
798987
struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
799988
static constexpr const char* kName = "Fuse";
800989
static constexpr bool kIsPure = false;
@@ -893,6 +1082,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931082
};
8941083

8951084
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
1085+
TVM_REGISTER_INST_KIND_TRAITS(MergeTraits);
8961086
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
8971087
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
8981088
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);

src/tir/schedule/schedule.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
153153
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
154154
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
155155
/******** (FFI) Transform loops ********/
156+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
156157
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
157158
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
158159
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")

src/tir/schedule/traced_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
176176

177177
/******** Schedule: Transform loops ********/
178178

179+
LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
180+
LoopRV result = ConcreteScheduleNode::Merge(loop_rvs);
181+
static const InstructionKind& kind = InstructionKind::Get("Merge");
182+
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
183+
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
184+
/*attrs=*/{},
185+
/*outputs=*/{result}));
186+
return result;
187+
}
188+
179189
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
180190
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);
181191

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
6161
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
6262
/******** Schedule: Transform loops ********/
6363
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
64+
LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
6465
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
6566
bool preserve_unit_iters) final;
6667
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;

0 commit comments

Comments
 (0)