@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451 return result_srefs;
452452}
453453
454+ class LoopReconstructor : private StmtMutator {
455+ public:
456+ explicit LoopReconstructor (Block scope_root,
457+ const std::vector<std::vector<const ForNode*>>& loops)
458+ : scope_root_(scope_root), loops_(loops) {}
459+
460+ using StmtMutator::operator ();
461+
462+ /* !
463+ * \brief Create the new nest loops induced by the given loops
464+ */
465+ void MakeNewLoop () {
466+ Array<Var> new_loop_vars;
467+ Array<PrimExpr> new_loop_extents;
468+ Array<Stmt> new_stmts;
469+ for (size_t i = 0 ; i < loops_.size (); i++) {
470+ Map<Var, PrimExpr> var_map;
471+ for (size_t j = 0 ; j < loops_[i].size (); j++) {
472+ if (i == 0 ) {
473+ Var merged_var = loops_[i][j]->loop_var .copy_with_suffix (" _m" );
474+ new_loop_vars.push_back (merged_var);
475+ new_loop_extents.push_back (loops_[i][j]->extent );
476+ }
477+ var_map.Set (loops_[i][j]->loop_var , new_loop_vars[j]);
478+ }
479+ auto new_stmt = Substitute (loops_[i][0 ]->body , var_map);
480+ new_stmts.push_back (new_stmt);
481+ this ->need_remove_loop_ .push_back (loops_[i].back ());
482+ }
483+ auto new_loop = For (new_loop_vars[0 ], Integer (0 ), new_loop_extents[0 ], ForKind::kSerial ,
484+ SeqStmt (std::move (new_stmts)));
485+ this ->new_inner_loop_ = new_loop;
486+ for (size_t i = 1 ; i < new_loop_vars.size (); ++i) {
487+ const Var& loop_var = new_loop_vars[i];
488+ const PrimExpr& loop_extent = new_loop_extents[i];
489+ new_loop = For (loop_var, Integer (0 ), loop_extent, ForKind::kSerial , new_loop);
490+ }
491+ this ->new_outer_loop_ = new_loop;
492+ }
493+
494+ private:
495+ Stmt VisitStmt_ (const BlockNode* block) final {
496+ if (block != scope_root_.get ()) {
497+ return GetRef<Block>(block);
498+ }
499+ return StmtMutator::VisitStmt_ (block);
500+ }
501+
502+ Stmt VisitStmt_ (const ForNode* loop) final {
503+ if (loop == need_remove_loop_.back ()) {
504+ return new_outer_loop_;
505+ } else if (std::count (need_remove_loop_.begin (), need_remove_loop_.end (), loop)) {
506+ return Evaluate (0 );
507+ }
508+ return StmtMutator::VisitStmt_ (loop);
509+ }
510+
511+ Stmt VisitStmt_ (const SeqStmtNode* seq_stmt) final {
512+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_ (seq_stmt, true ));
513+ Array<Stmt> filtered;
514+ for (Stmt stmt : ret->seq ) {
515+ if (!is_no_op (stmt)) {
516+ filtered.push_back (std::move (stmt));
517+ }
518+ }
519+ ret = SeqStmt (filtered);
520+ if (ret->size () == 0 ) {
521+ return Evaluate (0 );
522+ } else if (ret->size () == 1 ) {
523+ return ret->seq [0 ];
524+ } else {
525+ return std::move (ret);
526+ }
527+ }
528+
529+ public:
530+ /* ! \brief The root block of the block scope */
531+ Block scope_root_;
532+ /* ! \brief The given loops to be merge */
533+ const std::vector<std::vector<const ForNode*>>& loops_;
534+ /* ! \brief The outermost new loop to replace the original loop */
535+ For new_outer_loop_{nullptr };
536+ /* ! \brief The innermost new loop to replace the original loop */
537+ For new_inner_loop_{nullptr };
538+ /* ! \brief The loops to be removed */
539+ std::vector<const ForNode*> need_remove_loop_;
540+ };
541+
542+ StmtSRef Merge (ScheduleState self, const Array<StmtSRef>& loop_srefs) {
543+ // Invariance
544+ // - The total repeat number has not changed for each direct child block.
545+ // - The execution order has not changed. (The block executes with the same
546+ // args and the same order with before.)
547+ arith::Analyzer analyzer;
548+ StmtSRef scope_root_sref;
549+ StmtSRef lca = GetSRefLowestCommonAncestor (loop_srefs);
550+ std::vector<std::vector<const ForNode*>> lca_nest_loops;
551+ // Step 1. check correctness
552+ std::vector<const ForNode*> nest_loop_loops;
553+ std::vector<PrimExpr> nest_loop_extents;
554+ for (size_t i = 0 ; i < loop_srefs.size (); i++) {
555+ const StmtSRef& sref = loop_srefs[i];
556+ auto scope_root_sref_ = GetScopeRoot (self, sref, /* require_stage_pipeline=*/ false );
557+ std::vector<PrimExpr> nest_loop_i_extents;
558+ std::vector<const ForNode*> nest_loop_i_loops;
559+ for (auto p = sref.get (); p != lca.get (); p = p->parent ) {
560+ if (auto loop = p->StmtAs <ForNode>()) {
561+ if (!loop->annotations .empty () || loop->thread_binding .defined ()) {
562+ throw HasAnnotationOrThreadBindingError (self->mod , GetRef<For>(loop));
563+ }
564+ CheckLoopStartsWithZero (self, GetRef<StmtSRef>(p), &analyzer);
565+ nest_loop_i_loops.push_back (loop);
566+ nest_loop_i_extents.push_back (loop->extent );
567+ }
568+ }
569+ lca_nest_loops.push_back (nest_loop_i_loops);
570+ const ForNode* outer_loop = nullptr ;
571+ for (auto iter = nest_loop_i_loops.rbegin (); iter != nest_loop_i_loops.rend (); ++iter) {
572+ if (outer_loop && !outer_loop->body .same_as (GetRef<For>(*iter))) {
573+ throw NotOnlyChildError (self->mod , GetRef<For>(outer_loop), GetRef<For>(*iter));
574+ }
575+ outer_loop = *iter;
576+ }
577+ if (i == 0 ) {
578+ scope_root_sref = scope_root_sref_;
579+ nest_loop_loops = nest_loop_i_loops;
580+ nest_loop_extents = nest_loop_i_extents;
581+ } else {
582+ if (scope_root_sref_.get () != scope_root_sref.get ()) {
583+ LOG (FATAL) << " ScheduleError: Expected the loops to be under the same block scope" ;
584+ throw ;
585+ }
586+ if (nest_loop_i_extents.size () != nest_loop_extents.size ()) {
587+ LOG (FATAL) << " ScheduleError: Merge loop's nesting depth must be same, but not" ;
588+ throw ;
589+ } else {
590+ for (size_t j = 0 ; j < nest_loop_i_extents.size (); j++) {
591+ if (!analyzer.CanProveEqual (nest_loop_i_extents[j], nest_loop_extents[j])) {
592+ LOG (FATAL) << " ScheduleError: Merge loop's `extent` must be same, but not."
593+ << " extent=[" << j << " ," << nest_loop_extents[j] << " ,"
594+ << nest_loop_i_extents[j] << " ]" ;
595+ throw ;
596+ }
597+ }
598+ }
599+ }
600+ }
601+ // Step 2. Create merged loops and replace the original loops
602+ Block scope_root = GetRef<Block>(scope_root_sref->StmtAs <BlockNode>());
603+ LoopReconstructor reconstructor (scope_root, lca_nest_loops);
604+ reconstructor.MakeNewLoop ();
605+ Block new_scope_root = Downcast<Block>(reconstructor (scope_root));
606+ // Step 3. Do the actual replacement
607+ self->Replace (scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
608+ return self->stmt2ref .at (reconstructor.new_inner_loop_ .get ());
609+ }
610+
454611StmtSRef Fuse (ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455612 // Invariance
456613 // - The total repeat number has not changed for each direct child block.
@@ -795,6 +952,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795952 friend struct ::tvm::tir::UnpackedInstTraits;
796953};
797954
955+ struct MergeTraits : public UnpackedInstTraits <MergeTraits> {
956+ static constexpr const char * kName = " Merge" ;
957+ static constexpr bool kIsPure = false ;
958+
959+ private:
960+ static constexpr size_t kNumInputs = 1 ;
961+ static constexpr size_t kNumAttrs = 0 ;
962+ static constexpr size_t kNumDecisions = 0 ;
963+
964+ template <size_t delta>
965+ static TVM_ALWAYS_INLINE void _SetInputs (const runtime::TVMArgsSetter& setter,
966+ const Array<ObjectRef>& inputs) {
967+ setter (delta, inputs);
968+ }
969+
970+ static LoopRV UnpackedApplyToSchedule (Schedule sch, Array<LoopRV> loop_rvs) {
971+ return sch->Merge (loop_rvs);
972+ }
973+
974+ static String UnpackedAsPython (Array<String> outputs, Array<String> loop_rvs) {
975+ PythonAPICall py (" merge" );
976+ for (const String& loop_rv : loop_rvs) {
977+ py.Input (" " , loop_rv);
978+ }
979+ py.SingleOutput (outputs);
980+ return py.Str ();
981+ }
982+
983+ template <typename >
984+ friend struct ::tvm::tir::UnpackedInstTraits;
985+ };
986+
798987struct FuseTraits : public UnpackedInstTraits <FuseTraits> {
799988 static constexpr const char * kName = " Fuse" ;
800989 static constexpr bool kIsPure = false ;
@@ -893,6 +1082,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931082};
8941083
// One TVM_REGISTER_INST_KIND_TRAITS invocation per traits struct; MergeTraits is added next to
// the other loop-transformation traits.
// NOTE(review): presumably each invocation makes the instruction kind known to the tracing
// machinery under the struct's kName -- confirm against the macro definition.
8951084TVM_REGISTER_INST_KIND_TRAITS (SplitTraits);
1085+ TVM_REGISTER_INST_KIND_TRAITS (MergeTraits);
8961086TVM_REGISTER_INST_KIND_TRAITS (FuseTraits);
8971087TVM_REGISTER_INST_KIND_TRAITS (ReorderTraits);
8981088TVM_REGISTER_INST_KIND_TRAITS (AddUnitLoopTraits);
0 commit comments