Skip to content

Commit 14daf78

Browse files
author
Umang Yadav
committed
Fix extent one for the post_stmt in loop partition
1 parent 5f9c5e4 commit 14daf78

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

src/pass/loop_partition.cc

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -509,46 +509,42 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
509509
bool pre_stmt_recurse = true;
510510
if (middle_interval_i->HasLowerBound()) {
511511
body_begin = ir::Simplify(middle_interval.min());
512-
if (!analyzer_.CanProve(body_begin == min)) {
513-
Expr cond = (body_begin - min >= 0);
514-
if (!analyzer_.CanProve(cond)) {
515-
LOG(WARNING) << "Cannot prove: " << cond
516-
<< ", when generating the pre doubt loop";
517-
body_begin = Max::make(body_begin, min);
518-
// stop recursing on this interval if we can't prove it has non-negative length
519-
pre_stmt_recurse = false;
520-
}
521-
if (!partition_thread_scope) {
522-
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
523-
pre_stmt = MakeFor(node, body_begin - min, pre_body);
524-
}
512+
Expr cond = (body_begin - min >= 0);
513+
if (!analyzer_.CanProve(cond)) {
514+
LOG(WARNING) << "Cannot prove: " << cond
515+
<< ", when generating the pre doubt loop";
516+
body_begin = Max::make(body_begin, min);
517+
// stop recursing on this interval if we can't prove it has non-negative length
518+
pre_stmt_recurse = false;
519+
}
520+
if (!partition_thread_scope) {
521+
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
522+
pre_stmt = MakeFor(node, body_begin - min, pre_body);
525523
}
526524
} else {
527525
body_begin = min;
528526
}
529527

530528
// Calculating post-subrange and generating code for it.
531-
// post-subrange = [post_doubt_begin, max]
529+
// post-subrange = [post_doubt_begin, max+1)
532530
Expr post_doubt_begin;
533531
Stmt post_stmt;
534532
bool post_stmt_recurse = true;
535533
if (middle_interval_i->HasUpperBound()) {
536534
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
537-
if (!analyzer_.CanProve(middle_interval.max() == max)) {
538-
// require the extent to be non-negative
539-
Expr cond = (max - post_doubt_begin + 1 >= 0);
540-
if (!analyzer_.CanProve(cond)) {
541-
LOG(WARNING) << "Cannot prove: " << cond
542-
<< ", when generating the post doubt loop";
543-
post_doubt_begin = Min::make(post_doubt_begin, max);
544-
// stop recursing on this interval if we can't prove it has non-negative length
545-
post_stmt_recurse = false;
546-
}
547-
if (!partition_thread_scope) {
548-
Stmt post_body =
549-
Substitute(body, {{Var{var}, var + post_doubt_begin}});
550-
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
551-
}
535+
// require the extent to be non-negative
536+
Expr cond = (max - post_doubt_begin + 1 >= 0);
537+
if (!analyzer_.CanProve(cond)) {
538+
LOG(WARNING) << "Cannot prove: " << cond
539+
<< ", when generating the post doubt loop";
540+
post_doubt_begin = Min::make(post_doubt_begin, max+1);
541+
// stop recursing on this interval if we can't prove it has non-negative length
542+
post_stmt_recurse = false;
543+
}
544+
if (!partition_thread_scope) {
545+
Stmt post_body =
546+
Substitute(body, {{Var{var}, var + post_doubt_begin}});
547+
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
552548
}
553549
} else {
554550
post_doubt_begin = max + 1;

tests/python/unittest/test_pass_bound_checkers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def lower(sch, args):
3737
bounds = tvm.schedule.InferBound(sch)
3838
stmt = tvm.schedule.ScheduleOps(sch, bounds)
3939
stmt = tvm.ir_pass.LoopPartition(stmt, True)
40+
stmt = tvm.ir_pass.RemoveNoOp(stmt)
4041
stmt = tvm.ir_pass.StorageFlatten(stmt, binds, 64, True)
4142
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
4243
stmt = tvm.ir_pass.VectorizeLoop(stmt)

0 commit comments

Comments
 (0)