@@ -451,6 +451,165 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
451451 return result_srefs;
452452}
453453
454+ class LoopReconstructor : private StmtMutator {
455+ public:
456+ explicit LoopReconstructor (Block scope_root, std::vector<std::vector<const ForNode*>> loops)
457+ : scope_root_(scope_root), loops_(loops) {}
458+
459+ using StmtMutator::operator ();
460+
461+ /* !
462+ * \brief Create the new nest loops induced by the given loops
463+ */
464+ void MakeNewLoop () {
465+ Array<Var> new_loop_vars;
466+ Array<PrimExpr> new_loop_extents;
467+ Array<Stmt> new_stmts;
468+ for (size_t i = 0 ; i < loops_.size (); i++) {
469+ Map<Var, PrimExpr> var_map;
470+ for (size_t j = 0 ; j < loops_[i].size (); j++) {
471+ if (i == 0 ) {
472+ std::string suffix = loops_[i][j]->loop_var ->name_hint ;
473+ suffix += " _m" ;
474+ int bits = loops_[i][j]->loop_var .dtype ().bits ();
475+ Var merged_var (suffix, DataType::Int (bits));
476+ new_loop_vars.push_back (merged_var);
477+ new_loop_extents.push_back (loops_[i][j]->extent );
478+ }
479+ var_map.Set (loops_[i][j]->loop_var , new_loop_vars[j]);
480+ }
481+ auto new_stmt = Substitute (loops_[i][0 ]->body , var_map);
482+ new_stmts.push_back (new_stmt);
483+ this ->need_remove_loop_ .push_back (loops_[i].back ());
484+ }
485+ auto new_loop = For (new_loop_vars[0 ], Integer (0 ), new_loop_extents[0 ], ForKind::kSerial ,
486+ SeqStmt (std::move (new_stmts)));
487+ this ->new_inner_loop_ = new_loop;
488+ for (size_t i = 1 ; i < new_loop_vars.size (); ++i) {
489+ const Var& loop_var = new_loop_vars[i];
490+ const PrimExpr& loop_extent = new_loop_extents[i];
491+ new_loop = For (loop_var, Integer (0 ), loop_extent, ForKind::kSerial , new_loop);
492+ }
493+ this ->new_outer_loop_ = new_loop;
494+ }
495+
496+ private:
497+ Stmt VisitStmt_ (const BlockNode* block) final {
498+ if (block != scope_root_.get ()) {
499+ return GetRef<Block>(block);
500+ }
501+ return StmtMutator::VisitStmt_ (block);
502+ }
503+
504+ Stmt VisitStmt_ (const ForNode* loop) final {
505+ if (loop == need_remove_loop_.back ()) {
506+ return new_outer_loop_;
507+ } else if (std::count (need_remove_loop_.begin (), need_remove_loop_.end (), loop)) {
508+ return Evaluate (0 );
509+ }
510+ return StmtMutator::VisitStmt_ (loop);
511+ }
512+
513+ Stmt VisitStmt_ (const SeqStmtNode* seq_stmt) final {
514+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_ (seq_stmt, true ));
515+ Array<Stmt> filtered;
516+ for (Stmt stmt : ret->seq ) {
517+ if (!is_no_op (stmt)) {
518+ filtered.push_back (std::move (stmt));
519+ }
520+ }
521+ ret = SeqStmt (filtered);
522+ if (ret->size () == 0 ) {
523+ return Evaluate (0 );
524+ } else if (ret->size () == 1 ) {
525+ return ret->seq [0 ];
526+ } else {
527+ return std::move (ret);
528+ }
529+ }
530+
531+ public:
532+ /* ! \brief The root block of the block scope */
533+ Block scope_root_;
534+ /* ! \brief The given loops to be merge */
535+ std::vector<std::vector<const ForNode*>> loops_;
536+ /* ! \brief The outermost new loop to replace the original loop */
537+ For new_outer_loop_{nullptr };
538+ /* ! \brief The innermost new loop to replace the original loop */
539+ For new_inner_loop_{nullptr };
540+ /* ! \brief The loops to be removed */
541+ std::vector<const ForNode*> need_remove_loop_;
542+ };
543+
544+ StmtSRef Merge (ScheduleState self, const Array<StmtSRef>& loop_srefs) {
545+ // Invariance
546+ // - The total repeat number has not changed for each direct child block.
547+ // - The execution order has not changed. (The block executes with the same
548+ // args and the same order with before.)
549+ arith::Analyzer analyzer;
550+ StmtSRef scope_root_sref;
551+ StmtSRef lca = GetSRefLowestCommonAncestor (loop_srefs);
552+ std::vector<std::vector<const ForNode*>> lca_nest_loops;
553+ // Step 1. check correctness
554+ std::vector<const ForNode*> nest_loop_loops;
555+ std::vector<PrimExpr> nest_loop_extents;
556+ for (size_t i = 0 ; i < loop_srefs.size (); i++) {
557+ const StmtSRef& sref = loop_srefs[i];
558+ auto scope_root_sref_ = GetScopeRoot (self, sref, /* require_stage_pipeline=*/ false );
559+ std::vector<PrimExpr> nest_loop_i_extents;
560+ std::vector<const ForNode*> nest_loop_i_loops;
561+ for (auto p = sref.get (); p != lca.get (); p = p->parent ) {
562+ if (auto loop = p->StmtAs <ForNode>()) {
563+ if (!loop->annotations .empty () || loop->thread_binding .defined ()) {
564+ throw HasAnnotationOrThreadBindingError (self->mod , GetRef<For>(loop));
565+ }
566+ CheckLoopStartsWithZero (self, GetRef<StmtSRef>(p), &analyzer);
567+ nest_loop_i_loops.push_back (loop);
568+ nest_loop_i_extents.push_back (loop->extent );
569+ }
570+ }
571+ lca_nest_loops.push_back (nest_loop_i_loops);
572+ const ForNode* outer_loop = nullptr ;
573+ for (auto iter = nest_loop_i_loops.rbegin (); iter != nest_loop_i_loops.rend (); ++iter) {
574+ if (outer_loop && !outer_loop->body .same_as (GetRef<For>(*iter))) {
575+ throw NotOnlyChildError (self->mod , GetRef<For>(outer_loop), GetRef<For>(*iter));
576+ }
577+ outer_loop = *iter;
578+ }
579+ if (i == 0 ) {
580+ scope_root_sref = scope_root_sref_;
581+ nest_loop_loops = nest_loop_i_loops;
582+ nest_loop_extents = nest_loop_i_extents;
583+ } else {
584+ if (scope_root_sref_.get () != scope_root_sref.get ()) {
585+ LOG (FATAL) << " ScheduleError: Expected the loops to be under the same block scope" ;
586+ throw ;
587+ }
588+ if (nest_loop_i_extents.size () != nest_loop_extents.size ()) {
589+ LOG (FATAL) << " ScheduleError: Merge loop's nesting depth must be same, but not" ;
590+ throw ;
591+ } else {
592+ for (size_t j = 0 ; j < nest_loop_i_extents.size (); j++) {
593+ if (!analyzer.CanProveEqual (nest_loop_i_extents[j], nest_loop_extents[j])) {
594+ LOG (FATAL) << " ScheduleError: Merge loop's `extent` must be same, but not."
595+ << " extent=[" << j << " ," << nest_loop_extents[j] << " ,"
596+ << nest_loop_i_extents[j] << " ]" ;
597+ throw ;
598+ }
599+ }
600+ }
601+ }
602+ }
603+ // Step 2. Create merged loops and replace the original loops
604+ Block scope_root = GetRef<Block>(scope_root_sref->StmtAs <BlockNode>());
605+ LoopReconstructor reconstructor (scope_root, lca_nest_loops);
606+ reconstructor.MakeNewLoop ();
607+ Block new_scope_root = Downcast<Block>(reconstructor (scope_root));
608+ // Step 3. Do the actual replacement
609+ self->Replace (scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
610+ return self->stmt2ref .at (reconstructor.new_inner_loop_ .get ());
611+ }
612+
454613StmtSRef Fuse (ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455614 // Invariance
456615 // - The total repeat number has not changed for each direct child block.
@@ -795,6 +954,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
795954 friend struct ::tvm::tir::UnpackedInstTraits;
796955};
797956
957+ struct MergeTraits : public UnpackedInstTraits <MergeTraits> {
958+ static constexpr const char * kName = " Merge" ;
959+ static constexpr bool kIsPure = false ;
960+
961+ private:
962+ static constexpr size_t kNumInputs = 1 ;
963+ static constexpr size_t kNumAttrs = 0 ;
964+ static constexpr size_t kNumDecisions = 0 ;
965+
966+ template <size_t delta>
967+ static TVM_ALWAYS_INLINE void _SetInputs (const runtime::TVMArgsSetter& setter,
968+ const Array<ObjectRef>& inputs) {
969+ setter (delta, inputs);
970+ }
971+
972+ static LoopRV UnpackedApplyToSchedule (Schedule sch, Array<LoopRV> loop_rvs) {
973+ return sch->Merge (loop_rvs);
974+ }
975+
976+ static String UnpackedAsPython (Array<String> outputs, Array<String> loop_rvs) {
977+ PythonAPICall py (" merge" );
978+ for (const String& loop_rv : loop_rvs) {
979+ py.Input (" " , loop_rv);
980+ }
981+ py.SingleOutput (outputs);
982+ return py.Str ();
983+ }
984+
985+ template <typename >
986+ friend struct ::tvm::tir::UnpackedInstTraits;
987+ };
988+
798989struct FuseTraits : public UnpackedInstTraits <FuseTraits> {
799990 static constexpr const char * kName = " Fuse" ;
800991 static constexpr bool kIsPure = false ;
@@ -893,6 +1084,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
8931084};
8941085
// One registration per traits struct; MergeTraits is the entry added alongside the Merge
// primitive. NOTE(review): presumably this macro is what exposes each instruction kind under
// its `kName` - confirm against the macro definition.
8951086TVM_REGISTER_INST_KIND_TRAITS (SplitTraits);
1087+ TVM_REGISTER_INST_KIND_TRAITS (MergeTraits);
8961088TVM_REGISTER_INST_KIND_TRAITS (FuseTraits);
8971089TVM_REGISTER_INST_KIND_TRAITS (ReorderTraits);
8981090TVM_REGISTER_INST_KIND_TRAITS (AddUnitLoopTraits);
0 commit comments