Skip to content

Commit 320fda0

Browse files
derisaviwweic
authored andcommitted
[TVM][ARITH] Teach BoundDeduce to handle the case in which target var can appear in rhs of expression (apache#2795)
* target variable can now appear in either lhs or rhs of the expression to be analyzed * removed extra spaces
1 parent 3c2558e commit 320fda0

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

src/arithmetic/bound_deducer.cc

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -207,24 +207,53 @@ void BoundDeducer::Init() {
207207
}
208208

209209
void BoundDeducer::Transform() {
210+
// We will ensure to set expr_ such that it contains target_
210211
if (const LT* op = expr_.as<LT>()) {
211-
is_greater = false;
212-
expr_ = op->a;
213-
// a < b -> a <= b - 1
214-
result = op->b - 1;
212+
if (GetPath(target_, op->a).empty()) {
213+
// a < b -> b >= a + 1
214+
is_greater = true;
215+
expr_ = op->b;
216+
result = op->a + 1;
217+
} else {
218+
// a < b -> a <= b - 1
219+
is_greater = false;
220+
expr_ = op->a;
221+
result = op->b - 1;
222+
}
215223
} else if (const LE* op = expr_.as<LE>()) {
216-
is_greater = false;
217-
expr_ = op->a;
218-
result = op->b;
224+
if (GetPath(target_, op->a).empty()) {
225+
// a <= b -> b >= a
226+
is_greater = true;
227+
expr_ = op->b;
228+
result = op->a;
229+
} else {
230+
is_greater = false;
231+
expr_ = op->a;
232+
result = op->b;
233+
}
219234
} else if (const GT* op = expr_.as<GT>()) {
220-
is_greater = true;
221-
expr_ = op->a;
222-
// a > b -> a >= b + 1
223-
result = op->b + 1;
235+
if (GetPath(target_, op->a).empty()) {
236+
// a > b -> b <= a - 1
237+
is_greater = false;
238+
expr_ = op->b;
239+
result = op->a - 1;
240+
} else {
241+
// a > b -> a >= b + 1
242+
is_greater = true;
243+
expr_ = op->a;
244+
result = op->b + 1;
245+
}
224246
} else if (const GE* op = expr_.as<GE>()) {
225-
is_greater = true;
226-
expr_ = op->a;
227-
result = op->b;
247+
if (GetPath(target_, op->a).empty()) {
248+
// a >= b -> b <= a
249+
is_greater = false;
250+
expr_ = op->b;
251+
result = op->a;
252+
} else {
253+
is_greater = true;
254+
expr_ = op->a;
255+
result = op->b;
256+
}
228257
} else {
229258
success = false;
230259
}

tests/python/unittest/test_arith_intset.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,56 @@ def test_deduce():
3838
b_s = tvm.arith.intset_interval(2, 3)
3939
c_s = tvm.arith.intset_interval(10, 15)
4040
d_s = tvm.arith.intset_interval(-3, -1)
41+
zero = tvm.const(0, "int32")
4142

4243
e0 = (-b)*a+c-d
4344
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
4445
ans0 = ((d - c) /(b*-1))
4546
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
4647

48+
# expression containing variable a is on rhs
49+
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
50+
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
51+
4752
e0 = d*a+c-d
4853
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
4954
ans0 = ((0-c)/d + 1)
5055
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
5156

57+
# expression containing variable a is on rhs
58+
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
59+
assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
60+
5261
e1 = (a*4+b < c)
5362
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
5463
ans1 = (((c - b) + -1)/4)
5564
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
5665

66+
# expression containing variable a is on rhs
67+
e1 = (c > a*4+b)
68+
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
69+
assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
70+
5771
e2 = (tvm.max(5, a * 4) < 0)
5872
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
5973
assert str(res2.max()) == "neg_inf"
6074
assert str(res2.min()) == "pos_inf"
6175

76+
# expression containing variable a is on rhs
77+
e2 = (zero < tvm.max(5, a * 4))
78+
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
79+
assert str(res2.max()) == "neg_inf"
80+
assert str(res2.min()) == "pos_inf"
81+
82+
6283
e3 = (-b)+a*c-d
6384
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
6485
ans3 = 2/c+1
6586
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
6687

88+
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
89+
assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
90+
6791
def test_check():
6892
a = tvm.var('a')
6993
b = tvm.var('b')
@@ -97,11 +121,13 @@ def test_basic(a1, a2, coff):
97121
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
98122
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
99123

100-
res1 = tvm.arith.DeduceBound(a, e0>17, {b: b_s}, {b: b_s})
124+
# expression containing variable a is on rhs
125+
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
101126
[x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
102127
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
103128

104-
res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s})
129+
# expression containing variable a is on rhs
130+
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
105131
[x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
106132
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
107133

@@ -127,15 +153,17 @@ def test_complex(a1, a2, coff):
127153
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
128154
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
129155

130-
res1 = tvm.arith.DeduceBound(a, e0<=63, {b: b_s}, {b: b_s})
156+
# expression containing variable a is on rhs
157+
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
131158
[t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
132159
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
133160

134161
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
135162
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
136163
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
137164

138-
res1 = tvm.arith.DeduceBound(a, e0>=63, {b: b_s}, {b: b_s})
165+
# expression containing variable a is on rhs
166+
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
139167
[t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
140168
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
141169

0 commit comments

Comments
 (0)