Skip to content

Commit 876cba0

Browse files
author
Umang Yadav
committed
1) Add EQ op to the deduce_bound and add unittests for the same
2) Add EQ support in the loop partition and add test for the same
1 parent aee16d8 commit 876cba0

File tree

4 files changed

+111
-18
lines changed

4 files changed

+111
-18
lines changed

src/arithmetic/bound_deducer.cc

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
6969
return v.path_;
7070
}
7171

72+
enum CompareOp {kGreater, kLess, kEqual};
73+
7274
// a visitor to deduce the bound of a variable from a expression
7375
class BoundDeducer: public IRVisitor {
7476
public:
@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
120122
} else {
121123
result_ -= op->a;
122124
result_ = - result_;
123-
is_greater_ = !is_greater_;
125+
comp_op = ReverseOp(comp_op);
124126
}
125127
Visit(left ? op->a : op->b);
126128
}
@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
138140
}
139141

140142
if (sign_operand == SignType::kNegative) {
141-
is_greater_ = !is_greater_;
143+
comp_op = ReverseOp(comp_op);
142144
} else if (sign_operand == SignType::kUnknown) {
143145
// unable to get the sign of operand
144146
success_ = false;
@@ -154,8 +156,12 @@ class BoundDeducer: public IRVisitor {
154156
// NOTE: this accounts for truc div behavior.
155157
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
156158

157-
if (is_greater_) {
159+
if (comp_op == kGreater) {
158160
result_ += 1;
161+
} else if (comp_op == kEqual) {
162+
// condition unsatisfiable as with truc div, it will change the expression
163+
success_ = false;
164+
return;
159165
} else {
160166
// NOTE: this is a bit sutble hack.
161167
//
@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
185191
}
186192

187193
Expr result_;
188-
bool is_greater_{true};
194+
CompareOp comp_op{kGreater};
189195
bool success_{true};
190196

191197
private:
192198
void Init();
193199
void Transform();
194200
void Relax();
195-
201+
CompareOp ReverseOp(CompareOp comp_op);
196202
Expr target_;
197203
Expr expr_;
198204
const std::unordered_map<const Variable*, IntSet>& hint_map_;
@@ -228,51 +234,71 @@ void BoundDeducer::Init() {
228234
Transform();
229235
}
230236

237+
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
238+
switch (comp_op) {
239+
case kEqual: return kEqual; // IntSet can not represent range for `NE
240+
case kGreater: return kLess;
241+
case kLess: return kGreater;
242+
default:
243+
return kGreater;
244+
}
245+
}
246+
231247
void BoundDeducer::Transform() {
232248
// We will ensure to set expr_ such that it contains target_
233249
if (const LT* op = expr_.as<LT>()) {
234250
if (GetPath(target_, op->a).empty()) {
235251
// a < b -> b >= a + 1
236-
is_greater_ = true;
252+
comp_op = kGreater;
237253
expr_ = op->b;
238254
result_ = op->a + 1;
239255
} else {
240256
// a < b -> a <= b - 1
241-
is_greater_ = false;
257+
comp_op = kLess;
242258
expr_ = op->a;
243259
result_ = op->b - 1;
244260
}
245261
} else if (const LE* op = expr_.as<LE>()) {
246262
if (GetPath(target_, op->a).empty()) {
247263
// a <= b -> b >= a
248-
is_greater_ = true;
264+
comp_op = kGreater;
249265
expr_ = op->b;
250266
result_ = op->a;
251267
} else {
252-
is_greater_ = false;
268+
comp_op = kLess;
253269
expr_ = op->a;
254270
result_ = op->b;
255271
}
256272
} else if (const GT* op = expr_.as<GT>()) {
257273
if (GetPath(target_, op->a).empty()) {
258274
// a > b -> b <= a - 1
259-
is_greater_ = false;
275+
comp_op = kLess;
260276
expr_ = op->b;
261277
result_ = op->a - 1;
262278
} else {
263279
// a > b -> a >= b + 1
264-
is_greater_ = true;
280+
comp_op = kGreater;
265281
expr_ = op->a;
266282
result_ = op->b + 1;
267283
}
268284
} else if (const GE* op = expr_.as<GE>()) {
269285
if (GetPath(target_, op->a).empty()) {
270286
// a >= b -> b <= a
271-
is_greater_ = false;
287+
comp_op = kLess;
288+
expr_ = op->b;
289+
result_ = op->a;
290+
} else {
291+
comp_op = kGreater;
292+
expr_ = op->a;
293+
result_ = op->b;
294+
}
295+
} else if (const EQ* op = expr_.as<EQ>()) {
296+
comp_op = kEqual;
297+
if (GetPath(target_, op->a).empty()) {
298+
// if the b == a -> a == b
272299
expr_ = op->b;
273300
result_ = op->a;
274301
} else {
275-
is_greater_ = true;
276302
expr_ = op->a;
277303
result_ = op->b;
278304
}
@@ -304,8 +330,16 @@ void BoundDeducer::Relax() {
304330
success_ = false;
305331
return;
306332
}
307-
expr_ = is_greater_ ? a.min() : a.max();
308-
result_ = is_greater_ ? b.max() : b.min();
333+
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
334+
// can not be resolved when either `i` or `j` or both are variables with
335+
// some Range OR `i` and `j` both should be a single point in IntSet
336+
if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
337+
|| !analyzer_.CanProve(a.min() == a.max()))) {
338+
success_ = false;
339+
return;
340+
}
341+
expr_ = (comp_op == kGreater) ? a.min() : a.max();
342+
result_ = (comp_op == kGreater) ? b.max() : b.min();
309343
}
310344

311345
IntSet DeduceBound(Expr v, Expr e,
@@ -315,7 +349,10 @@ IntSet DeduceBound(Expr v, Expr e,
315349
d.Deduce();
316350
if (!d.success_) return IntSet::nothing();
317351
Expr min = neg_inf(), max = pos_inf();
318-
if (d.is_greater_) {
352+
if (d.comp_op == kEqual) {
353+
min = d.result_;
354+
max = d.result_;
355+
} else if (d.comp_op == kGreater) {
319356
min = d.result_;
320357
} else {
321358
max = d.result_;

src/pass/loop_partition.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor {
226226

227227
private:
228228
Expr InverseCond(const Expr& cond) {
229-
// We expect most condition not to be of EQ or NE form.
230-
// Currently we do not handle inversing EQ or NE.
231229
Expr inverse_cond;
232230
if (const LT* op = cond.as<LT>()) {
233231
// a < b -> a >= b
@@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor {
241239
} else if (const GE* op = cond.as<GE>()) {
242240
// a >= b -> a < b
243241
inverse_cond = LT::make(op->a, op->b);
242+
} else if (const EQ* op = cond.as<EQ>()) {
243+
// a == b -> a != b
244+
inverse_cond = NE::make(op->a, op->b);
245+
// a != b -> a == b
246+
} else if (const NE* op = cond.as<NE>()) {
247+
inverse_cond = EQ::make(op->a, op->b);
244248
}
245249
return inverse_cond;
246250
}

tests/python/unittest/test_arith_deduce_bound.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,44 @@ def test_deduce():
8585
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
8686
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
8787

88+
# tests for `EQ` op
89+
res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
90+
assert_expr_equal(res4.max_value, b)
91+
assert_expr_equal(res4.min_value, b)
92+
93+
# Unsatisfiable `EQ`, variable as one of the Operand
94+
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
95+
assert str(res5.max_value) == "neg_inf"
96+
assert str(res5.min_value) == "pos_inf"
97+
98+
# variable `a` on the RHS side
99+
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
100+
assert_expr_equal(res6.max_value, 10)
101+
assert_expr_equal(res6.min_value, 10)
102+
103+
# Add, Sub in `EQ`
104+
e4 = ((a - c) == (b + d))
105+
ans4 = (b + d + c)
106+
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
107+
assert_expr_equal(res7.max_value, ans4)
108+
assert_expr_equal(res7.min_value, ans4)
109+
110+
# Satisfiable Mul in `EQ` with negative sign
111+
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
112+
assert_expr_equal(res8.max_value, -2)
113+
assert_expr_equal(res8.min_value, -2)
114+
115+
# Unsatisfiable Mul in `EQ`
116+
e5 = (4 * a == b)
117+
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
118+
assert str(res9.max_value) == "neg_inf"
119+
assert str(res9.min_value) == "pos_inf"
120+
121+
# Unsatisfiable Mul in `EQ`
122+
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
123+
assert str(res10.max_value) == "neg_inf"
124+
assert str(res10.min_value) == "pos_inf"
125+
88126

89127
def test_check():
90128
a = tvm.var('a')
@@ -175,5 +213,6 @@ def test_complex(a1, a2, coff):
175213

176214
if __name__ == "__main__":
177215
test_check()
216+
test_deduce()
178217
test_deduce_basic()
179218
test_deduce_complex()

tests/python/unittest/test_pass_loop_partition.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ def test_condition():
171171
stmt = tvm.ir_pass.Simplify(stmt)
172172
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
173173

174+
def test_condition_EQ():
175+
ib = tvm.ir_builder.create()
176+
m = tvm.var('m')
177+
n = tvm.var('n')
178+
with ib.for_range(0, 10, 'i') as i:
179+
ib.emit(tvm.make.Evaluate(
180+
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
181+
stmt = ib.get()
182+
stmt = tvm.ir_pass.LoopPartition(stmt, True)
183+
stmt = tvm.ir_pass.Simplify(stmt)
184+
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
185+
174186
def test_thread_axis2():
175187
n = tvm.convert(4096)
176188
m = tvm.var('m')
@@ -420,6 +432,7 @@ def test_simple_rfactor():
420432
test_thread_axis()
421433
test_vectorize()
422434
test_condition()
435+
test_condition_EQ()
423436
test_thread_axis2()
424437
test_everything_during_deduction()
425438
test_single_likely()

0 commit comments

Comments
 (0)