@@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
177177 using Parent = ExprMutator;
178178
179179 explicit IterMapRewriter (Analyzer* analyzer, const Map<Var, Range>& input_iters,
180- bool simplify_trivial_iterators, Array<String>* errors)
181- : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
180+ IterMapLevel check_level, bool simplify_trivial_iterators,
181+ Array<String>* errors)
182+ : analyzer_(analyzer),
183+ check_level_(check_level),
184+ errors_(*errors),
185+ padding_predicate_(const_false()) {
182186 for (auto kv : input_iters) {
183187 const Var& var = kv.first ;
184188 const Range& vrng = kv.second ;
@@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {
419423
420424 // Internal analyzer
421425 Analyzer* analyzer_;
426+ // Iter map check level
427+ IterMapLevel check_level_;
422428 // Error messages for each unresolved expression.
423429 Array<String>& errors_;
424430 // The var map
@@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
651657 if (predicate_induced_max.defined ())
652658 predicate_induced_max = predicate_induced_max.value () - base;
653659 }
654- Optional<IterSumExpr> opt = TryFuseIters (expr);
660+ Optional<IterSumExpr> opt = TryFuseIters (expr, check_level_ );
655661 ICHECK (!opt.defined () || opt.value ()->args .size () == 1 );
656662 // scale should be 1
657663 if (opt.defined () && is_one (opt.value ()->args [0 ]->scale )) {
@@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
702708 IterSumExpr NormalizeToIterWithOffset (IterSumExpr expr) {
703709 // We are normalizing a regular iter
704710 if (expr->args .size () < 1 ) return expr;
705- Optional<IterSumExpr> opt = TryFuseIters (expr);
711+ Optional<IterSumExpr> opt = TryFuseIters (expr, check_level_ );
706712 if (opt.defined ()) {
707713 return opt.value ();
708714 } else {
@@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
735741 * return a corresponding IterSumExpr with extra offset if needed.
736742 * Try to normalize IterSum into a fused IterMark
737743 * \param expr The input sum.
744+ * \param check_level The check level of iter mapping.
738745 * \return The sum with the fused IterMark and extra offset if succeed.
739746 */
740- Optional<IterSumExpr> TryFuseIters (IterSumExpr expr) {
747+ Optional<IterSumExpr> TryFuseIters (IterSumExpr expr, IterMapLevel check_level ) {
741748 // select the iterators in order
742749 std::vector<bool > visited (expr->args .size (), false );
743750 std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
758765 }
759766 // check if it can be remapped into a fused pattern.
760767 PrimExpr expected_extra_base = 0 ;
768+ PrimExpr tail_extent = 0 ;
761769 PrimExpr expected_scale = base_scale.value ();
762770 for (size_t i = 0 ; i < expr->args .size ();) {
763- // find j such that expr->args[j] has expected scale
764- size_t j = i == 0 ? base_index : 0 ;
765- for (; j < expr->args .size (); ++j) {
766- if (!visited[j] && analyzer_->CanProveEqual (expr->args [j]->scale , expected_scale)) break ;
771+ // find a position j such that expr->args[j] matches the expected scale
772+ int j = i == 0 ? base_index : expr->args .size () - 1 ;
773+
774+ size_t matched_pos = expr->args .size ();
775+ PrimExpr matched_scale{nullptr };
776+ bool is_exact_match{false };
777+
778+ for (; j >= 0 ; --j) {
779+ if (visited[j]) {
780+ continue ;
781+ }
782+ const PrimExpr& cur_scale = expr->args [j]->scale ;
783+
784+ // for bijective mapping, the matched scale must be equal to the expected scale
785+ if (analyzer_->CanProveEqual (cur_scale, expected_scale)) {
786+ matched_pos = j;
787+ matched_scale = cur_scale;
788+ is_exact_match = true ;
789+ break ;
790+ }
791+ if (check_level != IterMapLevel::Bijective && base_scale.value ()->value == 1 ) {
792+ // find the closest scale which is less than or equal to the expected scale
793+ if (analyzer_->CanProveGreaterEqual (expected_scale - cur_scale, 0 ) &&
794+ analyzer_->CanProveGreaterEqual (cur_scale, 0 )) {
795+ if (matched_pos == expr->args .size () ||
796+ analyzer_->CanProveLess (matched_scale - cur_scale, 0 )) {
797+ matched_pos = j;
798+ matched_scale = cur_scale;
799+ }
800+ }
801+ }
767802 }
768- if (j == expr->args .size ()) {
803+ if (matched_pos == expr->args .size ()) {
769804 return NullOpt;
770805 }
771806 // look for the longest constrained iter started from expr->args[j]
@@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
775810 // otherwise we expect the scale of i to be 2*5=10
776811 Optional<IterSumExpr> constraint_to_match;
777812 for (const IterSumExpr& iter : constrained_iters_flattened_) {
778- if (IterSplitEqual (expr->args [j ], iter->args .back (), false )) {
779- // find a predicate started from expr->args[j]
813+ if (IterSplitEqual (expr->args [matched_pos ], iter->args .back (), false )) {
814+ // find a predicate started from match position
780815 if (!constraint_to_match ||
781816 constraint_to_match.value ()->args .size () < iter->args .size ()) {
782817 constraint_to_match = iter;
@@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
793828 size_t k = 0 ;
794829 for (; k < expr->args .size (); ++k) {
795830 if (!visited[k] && IterSplitEqual (expr->args [k], *it, false )) {
796- if (analyzer_->CanProveEqual ((*it)->scale * expected_scale , expr->args [k]->scale ))
831+ if (analyzer_->CanProveEqual ((*it)->scale * matched_scale , expr->args [k]->scale ))
797832 break ;
798833 }
799834 }
@@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
806841 auto iter = sum_fuse_map_.find (constraint_to_match.value ());
807842 ICHECK (iter != sum_fuse_map_.end ());
808843 const IterMarkWithOffset& iter_matched = iter->second ;
809- grouped_iters.emplace_back (iter_matched.mark , expected_scale);
810- expected_extra_base += iter_matched.offset * expected_scale;
811- expected_scale *= iter_matched.mark ->extent ;
844+ grouped_iters.emplace_back (iter_matched.mark , div (matched_scale, base_scale.value ()));
845+ expected_extra_base += iter_matched.offset * matched_scale;
846+ if (!is_exact_match) {
847+ tail_extent += expected_scale - matched_scale;
848+ }
849+ expected_scale = matched_scale * iter_matched.mark ->extent ;
812850 // move forward
813851 i += constraint_to_match.value ()->args .size ();
814852 } else {
815853 // constraint_to_match not found, skip this iterator
816- visited[j] = true ;
817- IterSplitExpr arg = expr->args [j];
818- arg.CopyOnWrite ()->scale =
819- analyzer_->Simplify (div (expr->args [j]->scale , base_scale.value ()));
854+ visited[matched_pos] = true ;
855+ IterSplitExpr arg = expr->args [matched_pos];
856+ arg.CopyOnWrite ()->scale = analyzer_->Simplify (div (arg->scale , base_scale.value ()));
820857 flattened_iters.push_back (arg);
821858 grouped_iters.push_back (arg);
822- expected_scale *= expr->args [j]->extent ;
859+ if (!is_exact_match) {
860+ tail_extent += expected_scale - matched_scale;
861+ }
862+ expected_scale = matched_scale * expr->args [matched_pos]->extent ;
823863 ++i;
824864 }
825865 }
@@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
843883 expr->base + expected_extra_base);
844884 } else {
845885 // new iter, form a new mark
846- IterMark mark = IterMark (structured_form, div (expected_scale, base_scale.value ()));
886+ IterMark mark =
887+ IterMark (structured_form, div (expected_scale, base_scale.value ()) + tail_extent);
847888 sum_fuse_map_[flattened_form] = IterMarkWithOffset (mark, 0 );
848889 flattened_map_[structured_form] = flattened_form;
849890 return IterSumExpr ({IterSplitExpr (mark, base_scale.value ())},
@@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
10861127 constraints.begin (), constraints.end (),
10871128 [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size ; });
10881129
1089- IterMapRewriter rewriter (analyzer, constrained_input_iters, simplify_trivial_iterators ,
1090- &result->errors );
1130+ IterMapRewriter rewriter (analyzer, constrained_input_iters, check_level ,
1131+ simplify_trivial_iterators, &result->errors );
10911132 // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
10921133 for (const IterConstraint& constraint : constraints) {
10931134 auto res = rewriter.RewriteIterConstraint (constraint.iter , constraint.lower_bound ,
@@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
12811322 } else if (sum->args .size () == 1 ) {
12821323 return sum;
12831324 }
1284- auto opt_fused = TryFuseIters (sum);
1325+ auto opt_fused = TryFuseIters (sum, check_level_ );
12851326 if (!opt_fused) {
12861327 ErrorLogger (this ) << " Dividend " << tvm::PrettyPrint (original_dividend)
12871328 << " , can't be written as a single fused IterSum" ;
0 commit comments