Skip to content

Commit 2c1bb3d

Browse files
[TIR] Add merge primitive for TIR schedule
1 parent 0d0d2f0 commit 2c1bb3d

File tree

10 files changed

+514
-0
lines changed

10 files changed

+514
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
292292
*/
293293
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
294294
/******** Schedule: Transform loops ********/
295+
/*!
296+
* \brief Merge a list of loops into one. The loops under their LCA requires:
297+
* 1) Under the same scope.
298+
* 2) Can't have annotations or thread bindings
299+
* 3) Start with 0 and have same domain.
300+
* 4) The inner loop must be the only child of the outer loop.
301+
* \param loop_rvs The loops to the loops to be merged
302+
* \return The new loop after merge
303+
*/
304+
virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
295305
/*!
296306
* \brief Fuse a list of consecutive loops into one. It requires:
297307
* 1) The loops can't have annotations or thread bindings.

python/tvm/tir/schedule/schedule.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,84 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
541541
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member
542542

543543
########## Schedule: Transform loops ##########
544+
@type_checked
545+
def merge(
546+
self,
547+
*loops: List[LoopRV],
548+
) -> LoopRV:
549+
"""Merge a list of loops into one. The loops under their LCA requires:
550+
1) Under the same scope
551+
2) Can't have annotations or thread bindings.
552+
3) Start with 0 and have same domain
553+
4) The inner loop must be the only child of the outer loop.
554+
555+
Parameters
556+
----------
557+
*loops : List[LoopRV]
558+
The loops to be merged
559+
560+
Returns
561+
-------
562+
fused_loop : LoopRV
563+
The new loop after merge
564+
565+
Examples
566+
--------
567+
568+
Before applying merge, in TensorIR, the IR is:
569+
570+
.. code-block:: python
571+
572+
@T.prim_func
573+
def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
574+
A = T.match_buffer(a, (128, 128))
575+
B = T.match_buffer(b, (128, 128))
576+
C = T.match_buffer(c, (128, 128))
577+
for i, j in T.grid(128, 128):
578+
with T.block("B"):
579+
vi, vj = T.axis.remap("SS", [i, j])
580+
B[vi, vj] = A[vi, vj] * 2.0
581+
for i, j in T.grid(128, 128):
582+
with T.block("C"):
583+
vi, vj = T.axis.remap("SS", [i, j])
584+
C[vi, vj] = A[vi, vj] * 2.0
585+
586+
Create the schedule and do fuse:
587+
588+
.. code-block:: python
589+
590+
sch = tir.Schedule(before_fuse)
591+
i1, _ = sch.get_loops(sch.get_block("B"))
592+
i2, _ = sch.get_loops(sch.get_block("C"))
593+
sch.merge(i1, i2)
594+
print(sch.mod["main"].script())
595+
596+
After applying fuse, the IR becomes:
597+
598+
.. code-block:: python
599+
600+
@T.prim_func
601+
def after_fuse(a: T.handle, b: T.handle, c: T.handle) -> None:
602+
A = T.match_buffer(a, (128, 128))
603+
B = T.match_buffer(b, (128, 128))
604+
C = T.match_buffer(c, (128, 128))
605+
# the 2 loops are merged into 1
606+
for i_m in range(128):
607+
for j in range(128):
608+
with T.block("B"):
609+
vi, vj = T.axis.remap("SS", [i_m, j])
610+
T.reads(A[vi, vj])
611+
T.writes(B[vi, vj])
612+
B[vi, vj] = A[vi, vj] * T.float32(2)
613+
for j in range(128):
614+
with T.block("C"):
615+
vi, vj = T.axis.remap("SS", [i_m, j])
616+
T.reads(A[vi, vj])
617+
T.writes(C[vi, vj])
618+
C[vi, vj] = A[vi, vj] * T.float32(2)
619+
"""
620+
return _ffi_api.ScheduleMerge(self, loops) # type: ignore # pylint: disable=no-member
621+
544622
@type_checked
545623
def fuse(
546624
self,

src/tir/schedule/concrete_schedule.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
356356

357357
/******** Schedule: Transform loops ********/
358358

359+
LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
360+
CHECK(!loop_rvs.empty()) << "ValueError: 'merge' requires at least 1 loop(s)";
361+
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
362+
StmtSRef result{nullptr};
363+
TVM_TIR_SCHEDULE_BEGIN();
364+
result = tir::Merge(state_, loop_srefs);
365+
TVM_TIR_SCHEDULE_END("merge", this->error_render_level_);
366+
this->state_->DebugVerify();
367+
return CreateRV<LoopRV>(result);
368+
}
369+
359370
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
360371
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
361372
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class ConcreteScheduleNode : public ScheduleNode {
101101
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
102102
/******** Schedule: Transform loops ********/
103103
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
104+
LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
104105
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
105106
bool preserve_unit_iters) override;
106107
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;

src/tir/schedule/primitive.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
161161
*/
162162
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
163163
const Array<PrimExpr>& factors, bool preserve_unit_iters);
164+
165+
/*!
166+
* \brief Merge a list of loops into one. The loops under their LCA requires:
167+
* 1) Under the same scope.
168+
* 2) Can't have annotations or thread bindings
169+
* 3) Start with 0 and have same domain.
170+
* 4) The inner loop must be the only child of the outer loop.
171+
* \param self The state of the schedule
172+
* \param loop_srefs An array of srefs to the loops to be merged
173+
* \return The new loop after merge
174+
*/
175+
TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);
176+
164177
/*!
165178
* \brief Fuse a list of consecutive loops into one. It requires:
166179
* 1) The loops can't have annotations or thread bindings.

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,165 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451
return result_srefs;
452452
}
453453

454+
class LoopReconstructor : private StmtMutator {
455+
public:
456+
explicit LoopReconstructor(Block scope_root, std::vector<std::vector<const ForNode*>> loops)
457+
: scope_root_(scope_root), loops_(loops) {}
458+
459+
using StmtMutator::operator();
460+
461+
/*!
462+
* \brief Create the new nest loops induced by the given loops
463+
*/
464+
void MakeNewLoop() {
465+
Array<Var> new_loop_vars;
466+
Array<PrimExpr> new_loop_extents;
467+
Array<Stmt> new_stmts;
468+
for (size_t i = 0; i < loops_.size(); i++) {
469+
Map<Var, PrimExpr> var_map;
470+
for (size_t j = 0; j < loops_[i].size(); j++) {
471+
if (i == 0) {
472+
std::string suffix = loops_[i][j]->loop_var->name_hint;
473+
suffix += "_m";
474+
int bits = loops_[i][j]->loop_var.dtype().bits();
475+
Var merged_var(suffix, DataType::Int(bits));
476+
new_loop_vars.push_back(merged_var);
477+
new_loop_extents.push_back(loops_[i][j]->extent);
478+
}
479+
var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
480+
}
481+
auto new_stmt = Substitute(loops_[i][0]->body, var_map);
482+
new_stmts.push_back(new_stmt);
483+
this->need_remove_loop_.push_back(loops_[i].back());
484+
}
485+
auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial,
486+
SeqStmt(std::move(new_stmts)));
487+
this->new_inner_loop_ = new_loop;
488+
for (size_t i = 1; i < new_loop_vars.size(); ++i) {
489+
const Var& loop_var = new_loop_vars[i];
490+
const PrimExpr& loop_extent = new_loop_extents[i];
491+
new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop);
492+
}
493+
this->new_outer_loop_ = new_loop;
494+
}
495+
496+
private:
497+
Stmt VisitStmt_(const BlockNode* block) final {
498+
if (block != scope_root_.get()) {
499+
return GetRef<Block>(block);
500+
}
501+
return StmtMutator::VisitStmt_(block);
502+
}
503+
504+
Stmt VisitStmt_(const ForNode* loop) final {
505+
if (loop == need_remove_loop_.back()) {
506+
return new_outer_loop_;
507+
} else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), loop)) {
508+
return Evaluate(0);
509+
}
510+
return StmtMutator::VisitStmt_(loop);
511+
}
512+
513+
Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
514+
auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
515+
Array<Stmt> filtered;
516+
for (Stmt stmt : ret->seq) {
517+
if (!is_no_op(stmt)) {
518+
filtered.push_back(std::move(stmt));
519+
}
520+
}
521+
ret = SeqStmt(filtered);
522+
if (ret->size() == 0) {
523+
return Evaluate(0);
524+
} else if (ret->size() == 1) {
525+
return ret->seq[0];
526+
} else {
527+
return std::move(ret);
528+
}
529+
}
530+
531+
public:
532+
/*! \brief The root block of the block scope */
533+
Block scope_root_;
534+
/*! \brief The given loops to be merge */
535+
std::vector<std::vector<const ForNode*>> loops_;
536+
/*! \brief The outermost new loop to replace the original loop */
537+
For new_outer_loop_{nullptr};
538+
/*! \brief The innermost new loop to replace the original loop */
539+
For new_inner_loop_{nullptr};
540+
/*! \brief The loops to be removed */
541+
std::vector<const ForNode*> need_remove_loop_;
542+
};
543+
544+
StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
545+
// Invariance
546+
// - The total repeat number has not changed for each direct child block.
547+
// - The execution order has not changed. (The block executes with the same
548+
// args and the same order with before.)
549+
arith::Analyzer analyzer;
550+
StmtSRef scope_root_sref;
551+
StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
552+
std::vector<std::vector<const ForNode*>> lca_nest_loops;
553+
// Step 1. check correctness
554+
std::vector<const ForNode*> nest_loop_loops;
555+
std::vector<PrimExpr> nest_loop_extents;
556+
for (size_t i = 0; i < loop_srefs.size(); i++) {
557+
const StmtSRef& sref = loop_srefs[i];
558+
auto scope_root_sref_ = GetScopeRoot(self, sref, /*require_stage_pipeline=*/false);
559+
std::vector<PrimExpr> nest_loop_i_extents;
560+
std::vector<const ForNode*> nest_loop_i_loops;
561+
for (auto p = sref.get(); p != lca.get(); p = p->parent) {
562+
if (auto loop = p->StmtAs<ForNode>()) {
563+
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
564+
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
565+
}
566+
CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
567+
nest_loop_i_loops.push_back(loop);
568+
nest_loop_i_extents.push_back(loop->extent);
569+
}
570+
}
571+
lca_nest_loops.push_back(nest_loop_i_loops);
572+
const ForNode* outer_loop = nullptr;
573+
for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) {
574+
if (outer_loop && !outer_loop->body.same_as(GetRef<For>(*iter))) {
575+
throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(*iter));
576+
}
577+
outer_loop = *iter;
578+
}
579+
if (i == 0) {
580+
scope_root_sref = scope_root_sref_;
581+
nest_loop_loops = nest_loop_i_loops;
582+
nest_loop_extents = nest_loop_i_extents;
583+
} else {
584+
if (scope_root_sref_.get() != scope_root_sref.get()) {
585+
LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope";
586+
throw;
587+
}
588+
if (nest_loop_i_extents.size() != nest_loop_extents.size()) {
589+
LOG(FATAL) << "ScheduleError: Merge loop's nesting depth must be same, but not";
590+
throw;
591+
} else {
592+
for (size_t j = 0; j < nest_loop_i_extents.size(); j++) {
593+
if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) {
594+
LOG(FATAL) << "ScheduleError: Merge loop's `extent` must be same, but not."
595+
<< "extent=[" << j << "," << nest_loop_extents[j] << ","
596+
<< nest_loop_i_extents[j] << "]";
597+
throw;
598+
}
599+
}
600+
}
601+
}
602+
}
603+
// Step 2. Create merged loops and replace the original loops
604+
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
605+
LoopReconstructor reconstructor(scope_root, lca_nest_loops);
606+
reconstructor.MakeNewLoop();
607+
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
608+
// Step 3. Do the actual replacement
609+
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
610+
return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
611+
}
612+
454613
StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455614
// Invariance
456615
// - The total repeat number has not changed for each direct child block.
@@ -795,6 +954,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795954
friend struct ::tvm::tir::UnpackedInstTraits;
796955
};
797956

957+
struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
958+
static constexpr const char* kName = "Merge";
959+
static constexpr bool kIsPure = false;
960+
961+
private:
962+
static constexpr size_t kNumInputs = 1;
963+
static constexpr size_t kNumAttrs = 0;
964+
static constexpr size_t kNumDecisions = 0;
965+
966+
template <size_t delta>
967+
static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
968+
const Array<ObjectRef>& inputs) {
969+
setter(delta, inputs);
970+
}
971+
972+
static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
973+
return sch->Merge(loop_rvs);
974+
}
975+
976+
static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
977+
PythonAPICall py("merge");
978+
for (const String& loop_rv : loop_rvs) {
979+
py.Input("", loop_rv);
980+
}
981+
py.SingleOutput(outputs);
982+
return py.Str();
983+
}
984+
985+
template <typename>
986+
friend struct ::tvm::tir::UnpackedInstTraits;
987+
};
988+
798989
struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
799990
static constexpr const char* kName = "Fuse";
800991
static constexpr bool kIsPure = false;
@@ -893,6 +1084,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931084
};
8941085

8951086
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
1087+
TVM_REGISTER_INST_KIND_TRAITS(MergeTraits);
8961088
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
8971089
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
8981090
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);

src/tir/schedule/schedule.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
153153
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
154154
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
155155
/******** (FFI) Transform loops ********/
156+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
156157
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
157158
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
158159
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")

src/tir/schedule/traced_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
176176

177177
/******** Schedule: Transform loops ********/
178178

179+
LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
180+
LoopRV result = ConcreteScheduleNode::Merge(loop_rvs);
181+
static const InstructionKind& kind = InstructionKind::Get("Merge");
182+
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
183+
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
184+
/*attrs=*/{},
185+
/*outputs=*/{result}));
186+
return result;
187+
}
188+
179189
LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
180190
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);
181191

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
6161
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
6262
/******** Schedule: Transform loops ********/
6363
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
64+
LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
6465
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
6566
bool preserve_unit_iters) final;
6667
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;

0 commit comments

Comments
 (0)