Skip to content

Commit 1ef2e7d

Browse files
author
baoxinqi
committed
fix lint issues and compare bug
1 parent d1900bd commit 1ef2e7d

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ class IterSumExpr : public IterMapExpr {
283283
* \param predicate The predicate constraints on the input iterators
284284
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
285285
* \param analyzer Analyzer used to get context information.
286+
* \param diag_ctx Diagnostic context.
286287
*
287288
* \return The detected pattern if a match exists,
288289
* otherwise return an empty array.

src/arith/iter_affine_map.cc

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) {
7575
n->dtype = source->source->dtype;
7676
n->source = std::move(source);
7777
n->extent = n->source->extent;
78-
if (is_zero(n->source->min)) {
78+
if (!is_zero(n->source->min)) {
7979
n->extent = n->extent + n->source->min;
8080
}
8181
n->lower_factor = one;
@@ -233,13 +233,20 @@ class IterMapRewriter : public ExprMutator {
233233
collector.Collect(bindings);
234234
for (const IterMark& mark : collector.visited_) {
235235
if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) {
236+
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
237+
<< "Fail to normalize iter mark splits: " << mark);
236238
return false;
237239
}
238240
}
239241
if (require_bijective) {
240242
// all input marks must be visited
241243
for (const IterMark& mark : input_marks_) {
242-
if (collector.visited_.count(mark) == 0) return false;
244+
if (collector.visited_.count(mark) == 0) {
245+
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
246+
<< "The mapping is not bijective because input iter mark " << mark
247+
<< " is not covered, ");
248+
return false;
249+
}
243250
}
244251
}
245252
return true;
@@ -425,26 +432,48 @@ class IterMapRewriter : public ExprMutator {
425432
}
426433
if (j == splits.size()) {
427434
// we do not allow incomplete split if the bindings should be bijective
428-
if (require_bijective) return Array<IterSplitExpr>();
435+
if (require_bijective) {
436+
diag_ctx_.Emit(
437+
Diagnostic::Error(mark->source->span)
438+
<< "Do not allow incomplete split in bijective checking, expected_lower_factor="
439+
<< expected_lower_factor);
440+
return Array<IterSplitExpr>();
441+
}
429442
// look for the next split skipping this lower factor
430443
// For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
431444
// It is valid to only have [y / 6, y % 2] if bijective is not required
432445
// We can skip (y / 2) % 6
433446
j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
434447
// split not found
435-
if (j == splits.size()) return Array<IterSplitExpr>();
448+
if (j == splits.size()) {
449+
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
450+
<< "Fail to find split skipping the lower factor in bijective-free "
451+
"checking, expected_lower_factor="
452+
<< expected_lower_factor);
453+
return Array<IterSplitExpr>();
454+
}
436455
}
437456
used[j] = true;
438457
iters.push_back(splits[j]);
439458
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
440459
}
460+
441461
// Case 1. bijective is required.
442462
// We check the extent we calculate is consistent with the extent of the mark
443463
// Case 2. bijective is not required.
444-
// We check the extent we calculate is a factor of the extent of the mark
464+
// We check either
465+
// (1) the extent we calculate is a factor of the extent of the mark
445466
// For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
446-
if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) ||
447-
(!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) {
467+
// (2) the extent we calculate is larger than the max of the mark
468+
// For example, y \in [1, 8] [y / 18, y % 18] is valid.
469+
if ((require_bijective &&
470+
!(analyzer_->CanProveEqual(expected_lower_factor, mark->extent) && is_zero(mark->min))) ||
471+
(!require_bijective &&
472+
!(CanProveDivisible(mark->extent, expected_lower_factor) ||
473+
analyzer_->CanProve(mark->min + mark->extent <= expected_lower_factor)))) {
474+
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
475+
<< "Mark extent of " << mark
476+
<< " is not compatible with expected_lower_factor=" << expected_lower_factor);
448477
return Array<IterSplitExpr>();
449478
}
450479
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def var_dom(iters):
4444
return {var: tvm.ir.Range(0, ext) for var, ext in iters}
4545

4646

47-
def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0):
47+
def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_min=0, mark_extent=None):
4848
"""Check the sum expr have the right pattern."""
4949
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
5050
if extent == 1:
@@ -53,7 +53,9 @@ def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0):
5353
assert len(sum_expr.args) == 1
5454
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent)
5555
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale)
56-
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, min)
56+
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, mark_min)
57+
if mark_extent:
58+
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.extent, mark_extent)
5759
tvm.testing.assert_prim_expr_equal(sum_expr.base, base)
5860

5961

@@ -212,10 +214,10 @@ def test_predicate():
212214
# lower bound only
213215
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5)
214216
assert len(res) == 1
215-
assert_iter_sum_pattern(res[0], 124, 0, min=6)
217+
assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124)
216218
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6)
217219
assert len(res) == 1
218-
assert_iter_sum_pattern(res[0], 124, 0, min=6)
220+
assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124)
219221

220222
# lower bound + upper bound
221223
res = tvm.arith.detect_iter_map(
@@ -224,14 +226,14 @@ def test_predicate():
224226
tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128),
225227
)
226228
assert len(res) == 1
227-
assert_iter_sum_pattern(res[0], 122, 0, min=6)
229+
assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122)
228230
res = tvm.arith.detect_iter_map(
229231
[x[0] * 10 + y[0]],
230232
var_dom([x, y]),
231233
tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127),
232234
)
233235
assert len(res) == 1
234-
assert_iter_sum_pattern(res[0], 122, 0, min=6)
236+
assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122)
235237

236238
# non-standard form of predicate
237239
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0])

0 commit comments

Comments
 (0)