Skip to content

Commit 9b148f1

Browse files
yzhliutqchen
authored andcommitted
[schedule] Improve ceil_divide in tile/split (#3842)
1 parent d9bbdbc commit 9b148f1

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/schedule/message_passing.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ void PassDownDomain(const Stage& stage,
5656
arith::Analyzer* actx,
5757
bool allow_missing) {
5858
auto ceil_div = [actx](Expr a, Expr b) {
59+
if (actx->CanProve(a % b == 0)) {
60+
return actx->Simplify(a / b);
61+
}
5962
return actx->Simplify((a + (b - 1)) / b);
6063
};
6164

tests/python/unittest/test_schedule_bound_inference.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,33 @@ def test_bound3():
6969
assert(bounds[A1.op.axis[0]].extent.value==32)
7070
assert(bounds[A1.op.axis[1]].extent.value==16)
7171

72+
def test_bound_split_divisible():
73+
m = tvm.var('m')
74+
l = tvm.var('l')
75+
A = tvm.placeholder((8 * m, l), name='A')
76+
B = tvm.compute((8 * m, l), lambda i, j: A[i, j], name='B')
77+
s = tvm.create_schedule(B.op)
78+
xo, xi = s[B].split(B.op.axis[0], 8)
79+
bounds = tvm.schedule.InferBound(s)
80+
assert isinstance(bounds, tvm.container.Map)
81+
assert bounds[xo].extent == m
82+
assert bounds[xi].extent.value == 8
83+
84+
def test_bound_tile_divisible():
85+
m = tvm.var('m')
86+
l = tvm.var('l')
87+
shape = (8 * m, 32 * l)
88+
A = tvm.placeholder(shape, name='A')
89+
B = tvm.compute(shape, lambda i, j: A[i, j], name='B')
90+
s = tvm.create_schedule(B.op)
91+
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], 8, 32)
92+
bounds = tvm.schedule.InferBound(s)
93+
assert isinstance(bounds, tvm.container.Map)
94+
assert bounds[xo].extent == m
95+
assert bounds[xi].extent.value == 8
96+
assert bounds[yo].extent == l
97+
assert bounds[yi].extent.value == 32
98+
7299
def test_bound_fusesplit1():
73100
m = tvm.var('m')
74101
l = tvm.var('l')
@@ -393,3 +420,5 @@ def _check(B, A=A):
393420
test_bound_simplification_failure()
394421
test_bound_fusesplit1()
395422
test_bound_fusesplit2()
423+
test_bound_split_divisible()
424+
test_bound_tile_divisible()

0 commit comments

Comments
 (0)