@@ -451,6 +451,162 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451 return result_srefs;
452452}
453453
454+ class LoopReconstructor : private StmtMutator {
455+ public:
456+ explicit LoopReconstructor (Block scope_root, const std::vector<std::vector<For>>& loops)
457+ : scope_root_(scope_root), loops_(loops) {}
458+
459+ using StmtMutator::operator ();
460+
461+ /* !
462+ * \brief Create the new nest loops induced by the given loops
463+ */
464+ void MakeNewLoop () {
465+ Array<Var> new_loop_vars;
466+ Array<PrimExpr> new_loop_extents;
467+ Array<Stmt> new_stmts;
468+ for (size_t i = 0 ; i < loops_.size (); i++) {
469+ Map<Var, PrimExpr> var_map;
470+ for (size_t j = 0 ; j < loops_[i].size (); j++) {
471+ if (i == 0 ) {
472+ Var merged_var = loops_[i][j]->loop_var .copy_with_suffix (" _m" );
473+ new_loop_vars.push_back (merged_var);
474+ new_loop_extents.push_back (loops_[i][j]->extent );
475+ }
476+ var_map.Set (loops_[i][j]->loop_var , new_loop_vars[j]);
477+ }
478+ auto new_stmt = Substitute (loops_[i][0 ]->body , var_map);
479+ new_stmts.push_back (new_stmt);
480+ this ->need_remove_loop_ .push_back (loops_[i].back ());
481+ }
482+ auto new_loop = For (new_loop_vars[0 ], Integer (0 ), new_loop_extents[0 ], ForKind::kSerial ,
483+ SeqStmt (std::move (new_stmts)));
484+ this ->new_inner_loop_ = new_loop;
485+ for (size_t i = 1 ; i < new_loop_vars.size (); ++i) {
486+ const Var& loop_var = new_loop_vars[i];
487+ const PrimExpr& loop_extent = new_loop_extents[i];
488+ new_loop = For (loop_var, Integer (0 ), loop_extent, ForKind::kSerial , new_loop);
489+ }
490+ this ->new_outer_loop_ = new_loop;
491+ }
492+
493+ private:
494+ Stmt VisitStmt_ (const BlockNode* block) final {
495+ if (block != scope_root_.get ()) {
496+ return GetRef<Block>(block);
497+ }
498+ return StmtMutator::VisitStmt_ (block);
499+ }
500+
501+ Stmt VisitStmt_ (const ForNode* loop) final {
502+ if (GetRef<For>(loop) == need_remove_loop_.back ()) {
503+ return new_outer_loop_;
504+ } else if (std::count (need_remove_loop_.begin (), need_remove_loop_.end (), GetRef<For>(loop))) {
505+ return Evaluate (0 );
506+ }
507+ return StmtMutator::VisitStmt_ (loop);
508+ }
509+
510+ Stmt VisitStmt_ (const SeqStmtNode* seq_stmt) final {
511+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_ (seq_stmt, true ));
512+ Array<Stmt> filtered;
513+ for (Stmt stmt : ret->seq ) {
514+ if (!is_no_op (stmt)) {
515+ filtered.push_back (std::move (stmt));
516+ }
517+ }
518+ ret = SeqStmt (filtered);
519+ if (ret->size () == 0 ) {
520+ return Evaluate (0 );
521+ } else if (ret->size () == 1 ) {
522+ return ret->seq [0 ];
523+ } else {
524+ return std::move (ret);
525+ }
526+ }
527+
528+ public:
529+ /* ! \brief The root block of the block scope */
530+ Block scope_root_;
531+ /* ! \brief The given loops to be merge */
532+ const std::vector<std::vector<For>>& loops_;
533+ /* ! \brief The outermost new loop to replace the original loop */
534+ For new_outer_loop_{nullptr };
535+ /* ! \brief The innermost new loop to replace the original loop */
536+ For new_inner_loop_{nullptr };
537+ /* ! \brief The loops to be removed */
538+ std::vector<For> need_remove_loop_;
539+ };
540+
541+ StmtSRef Merge (ScheduleState self, const Array<StmtSRef>& loop_srefs) {
542+ // Invariance
543+ // - The total repeat number has not changed for each direct child block.
544+ // - The execution order has not changed. (The block executes with the same
545+ // args and the same order with before.)
546+ arith::Analyzer analyzer;
547+ StmtSRef scope_root_sref;
548+ StmtSRef lca = GetSRefLowestCommonAncestor (loop_srefs);
549+ std::vector<std::vector<For>> lca_nest_loops;
550+ // Step 1. check correctness
551+ std::vector<For> nest_loop_loops;
552+ std::vector<PrimExpr> nest_loop_extents;
553+ for (size_t i = 0 ; i < loop_srefs.size (); i++) {
554+ const StmtSRef& sref = loop_srefs[i];
555+ auto scope_root_sref_ = GetScopeRoot (self, sref, /* require_stage_pipeline=*/ false );
556+ std::vector<PrimExpr> nest_loop_i_extents;
557+ std::vector<For> nest_loop_i_loops;
558+ for (auto p = sref.get (); p != lca.get (); p = p->parent ) {
559+ if (auto loop = p->StmtAs <ForNode>()) {
560+ if (!loop->annotations .empty () || loop->thread_binding .defined ()) {
561+ throw HasAnnotationOrThreadBindingError (self->mod , GetRef<For>(loop));
562+ }
563+ CheckLoopStartsWithZero (self, GetRef<StmtSRef>(p), &analyzer);
564+ nest_loop_i_loops.push_back (GetRef<For>(loop));
565+ nest_loop_i_extents.push_back (loop->extent );
566+ }
567+ }
568+ lca_nest_loops.push_back (nest_loop_i_loops);
569+ const ForNode* outer_loop = nullptr ;
570+ for (auto iter = nest_loop_i_loops.rbegin (); iter != nest_loop_i_loops.rend (); ++iter) {
571+ if (outer_loop && !outer_loop->body .same_as (*iter)) {
572+ throw NotOnlyChildError (self->mod , GetRef<For>(outer_loop), *iter);
573+ }
574+ outer_loop = (*iter).get ();
575+ }
576+ if (i == 0 ) {
577+ scope_root_sref = scope_root_sref_;
578+ nest_loop_loops = nest_loop_i_loops;
579+ nest_loop_extents = nest_loop_i_extents;
580+ } else {
581+ if (scope_root_sref_.get () != scope_root_sref.get ()) {
582+ LOG (FATAL) << " ScheduleError: Expected the loops to be under the same block scope." ;
583+ throw ;
584+ }
585+ if (nest_loop_i_extents.size () != nest_loop_extents.size ()) {
586+ LOG (FATAL) << " ScheduleError: Merge loop's nesting depth must be same, but not." ;
587+ throw ;
588+ } else {
589+ for (size_t j = 0 ; j < nest_loop_i_extents.size (); j++) {
590+ if (!analyzer.CanProveEqual (nest_loop_i_extents[j], nest_loop_extents[j])) {
591+ LOG (FATAL) << " ScheduleError: Merge loop's `extent` must be same, but not."
592+ << " extent=[" << j << " ," << nest_loop_extents[j] << " ,"
593+ << nest_loop_i_extents[j] << " ]" ;
594+ throw ;
595+ }
596+ }
597+ }
598+ }
599+ }
600+ // Step 2. Create merged loops and replace the original loops
601+ Block scope_root = GetRef<Block>(scope_root_sref->StmtAs <BlockNode>());
602+ LoopReconstructor reconstructor (scope_root, lca_nest_loops);
603+ reconstructor.MakeNewLoop ();
604+ Block new_scope_root = Downcast<Block>(reconstructor (scope_root));
605+ // Step 3. Do the actual replacement
606+ self->Replace (scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
607+ return self->stmt2ref .at (reconstructor.new_inner_loop_ .get ());
608+ }
609+
454610StmtSRef Fuse (ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455611 // Invariance
456612 // - The total repeat number has not changed for each direct child block.
@@ -795,6 +951,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795951 friend struct ::tvm::tir::UnpackedInstTraits;
796952};
797953
954+ struct MergeTraits : public UnpackedInstTraits <MergeTraits> {
955+ static constexpr const char * kName = " Merge" ;
956+ static constexpr bool kIsPure = false ;
957+
958+ private:
959+ static constexpr size_t kNumInputs = 1 ;
960+ static constexpr size_t kNumAttrs = 0 ;
961+ static constexpr size_t kNumDecisions = 0 ;
962+
963+ template <size_t delta>
964+ static TVM_ALWAYS_INLINE void _SetInputs (const runtime::TVMArgsSetter& setter,
965+ const Array<ObjectRef>& inputs) {
966+ setter (delta, inputs);
967+ }
968+
969+ static LoopRV UnpackedApplyToSchedule (Schedule sch, Array<LoopRV> loop_rvs) {
970+ return sch->Merge (loop_rvs);
971+ }
972+
973+ static String UnpackedAsPython (Array<String> outputs, Array<String> loop_rvs) {
974+ PythonAPICall py (" merge" );
975+ for (const String& loop_rv : loop_rvs) {
976+ py.Input (" " , loop_rv);
977+ }
978+ py.SingleOutput (outputs);
979+ return py.Str ();
980+ }
981+
982+ template <typename >
983+ friend struct ::tvm::tir::UnpackedInstTraits;
984+ };
985+
798986struct FuseTraits : public UnpackedInstTraits <FuseTraits> {
799987 static constexpr const char * kName = " Fuse" ;
800988 static constexpr bool kIsPure = false ;
@@ -893,6 +1081,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931081};
8941082
// Register the instruction-kind traits via TVM_REGISTER_INST_KIND_TRAITS, one invocation per
// traits struct; MergeTraits is registered here together with the other traits listed below.
8951083TVM_REGISTER_INST_KIND_TRAITS (SplitTraits);
1084+ TVM_REGISTER_INST_KIND_TRAITS (MergeTraits);
8961085TVM_REGISTER_INST_KIND_TRAITS (FuseTraits);
8971086TVM_REGISTER_INST_KIND_TRAITS (ReorderTraits);
8981087TVM_REGISTER_INST_KIND_TRAITS (AddUnitLoopTraits);
0 commit comments