@@ -75,7 +75,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) {
7575 n->dtype = source->source ->dtype ;
7676 n->source = std::move (source);
7777 n->extent = n->source ->extent ;
78- if (is_zero (n->source ->min )) {
78+ if (! is_zero (n->source ->min )) {
7979 n->extent = n->extent + n->source ->min ;
8080 }
8181 n->lower_factor = one;
@@ -233,13 +233,20 @@ class IterMapRewriter : public ExprMutator {
233233 collector.Collect (bindings);
234234 for (const IterMark& mark : collector.visited_ ) {
235235 if (TryNormalizeSplits (mark, collector.mark2splits_ [mark], require_bijective).empty ()) {
236+ diag_ctx_.Emit (Diagnostic::Error (mark->source ->span )
237+ << " Fail to normalize iter mark splits: " << mark);
236238 return false ;
237239 }
238240 }
239241 if (require_bijective) {
240242 // all input marks must be visited
241243 for (const IterMark& mark : input_marks_) {
242- if (collector.visited_ .count (mark) == 0 ) return false ;
244+ if (collector.visited_ .count (mark) == 0 ) {
245+ diag_ctx_.Emit (Diagnostic::Error (mark->source ->span )
246+ << " The mapping is not bijective because input iter mark " << mark
247+ << " is not covered, " );
248+ return false ;
249+ }
243250 }
244251 }
245252 return true ;
@@ -425,26 +432,48 @@ class IterMapRewriter : public ExprMutator {
425432 }
426433 if (j == splits.size ()) {
427434 // we do not allow incomplete split if the bindings should be bijective
428- if (require_bijective) return Array<IterSplitExpr>();
435+ if (require_bijective) {
436+ diag_ctx_.Emit (
437+ Diagnostic::Error (mark->source ->span )
438+ << " Do not allow incomplete split in bijective checking, expected_lower_factor="
439+ << expected_lower_factor);
440+ return Array<IterSplitExpr>();
441+ }
429442 // look for the next split skipping this lower factor
430443 // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
431444 // It is valid to only have [y / 6, y % 2] if bijective is not required
432445 // We can skip (y / 2) % 6
433446 j = SearchSkipLowerFactor (splits, used, expected_lower_factor);
434447 // split not found
435- if (j == splits.size ()) return Array<IterSplitExpr>();
448+ if (j == splits.size ()) {
449+ diag_ctx_.Emit (Diagnostic::Error (mark->source ->span )
450+ << " Fail to find split skipping the lower factor in bijective-free "
451+ " checking, expected_lower_factor="
452+ << expected_lower_factor);
453+ return Array<IterSplitExpr>();
454+ }
436455 }
437456 used[j] = true ;
438457 iters.push_back (splits[j]);
439458 expected_lower_factor = splits[j]->lower_factor * splits[j]->extent ;
440459 }
460+
441461 // Case 1. bijective is required.
442462 // We check the extent we calculate is consistent with the extent of the mark
443463 // Case 2. bijective is not required.
444- // We check the extent we calculate is a factor of the extent of the mark
464+ // We check either
465+ // (1) the extent we calculate is a factor of the extent of the mark
445466 // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
446- if ((require_bijective && !analyzer_->CanProveEqual (expected_lower_factor, mark->extent )) ||
447- (!require_bijective && !CanProveDivisible (mark->extent , expected_lower_factor))) {
467+ // (2) the extent we calculate is larger than the max of the mark
468+ // For example, y \in [1, 8] [y / 18, y % 18] is valid.
469+ if ((require_bijective &&
470+ !(analyzer_->CanProveEqual (expected_lower_factor, mark->extent ) && is_zero (mark->min ))) ||
471+ (!require_bijective &&
472+ !(CanProveDivisible (mark->extent , expected_lower_factor) ||
473+ analyzer_->CanProve (mark->min + mark->extent <= expected_lower_factor)))) {
474+ diag_ctx_.Emit (Diagnostic::Error (mark->source ->span )
475+ << " Mark extent of " << mark
476+ << " is not compatible with expected_lower_factor=" << expected_lower_factor);
448477 return Array<IterSplitExpr>();
449478 }
450479 return Array<IterSplitExpr>(iters.rbegin (), iters.rend ());
0 commit comments