Skip to content

Commit 4b431c6

Browse files
Umang Yadavtqchen
authored andcommitted
1) Add EQ op to the deduce_bound and add unittests for the same (#3775)
2) Add EQ support in the loop partition and add test for the same 3) Change typo truc to trunc
1 parent 2536465 commit 4b431c6

File tree

4 files changed

+113
-19
lines changed

4 files changed

+113
-19
lines changed

src/arithmetic/bound_deducer.cc

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
6969
return v.path_;
7070
}
7171

72+
enum CompareOp {kGreater, kLess, kEqual};
73+
7274
// a visitor to deduce the bound of a variable from a expression
7375
class BoundDeducer: public IRVisitor {
7476
public:
@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
120122
} else {
121123
result_ -= op->a;
122124
result_ = - result_;
123-
is_greater_ = !is_greater_;
125+
comp_op = ReverseOp(comp_op);
124126
}
125127
Visit(left ? op->a : op->b);
126128
}
@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
138140
}
139141

140142
if (sign_operand == SignType::kNegative) {
141-
is_greater_ = !is_greater_;
143+
comp_op = ReverseOp(comp_op);
142144
} else if (sign_operand == SignType::kUnknown) {
143145
// unable to get the sign of operand
144146
success_ = false;
@@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor {
151153

152154
if (!divided) {
153155
// Handle non-divisible case
154-
// NOTE: this accounts for truc div behavior.
156+
// NOTE: this accounts for trunc div behavior.
155157
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
156158

157-
if (is_greater_) {
159+
if (comp_op == kGreater) {
158160
result_ += 1;
161+
} else if (comp_op == kEqual) {
162+
// condition unsatisfiable as with trunc div, it will change the expression
163+
success_ = false;
164+
return;
159165
} else {
160166
// NOTE: this is a bit sutble hack.
161167
//
@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
185191
}
186192

187193
Expr result_;
188-
bool is_greater_{true};
194+
CompareOp comp_op{kGreater};
189195
bool success_{true};
190196

191197
private:
192198
void Init();
193199
void Transform();
194200
void Relax();
195-
201+
CompareOp ReverseOp(CompareOp comp_op);
196202
Expr target_;
197203
Expr expr_;
198204
const std::unordered_map<const Variable*, IntSet>& hint_map_;
@@ -228,51 +234,72 @@ void BoundDeducer::Init() {
228234
Transform();
229235
}
230236

237+
CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
238+
switch (comp_op) {
239+
case kEqual: return kEqual; // IntSet can not represent range for `NE
240+
case kGreater: return kLess;
241+
case kLess: return kGreater;
242+
default:
243+
LOG(FATAL) << "Not a valid compare op";
244+
return kGreater; // return some default value
245+
}
246+
}
247+
231248
void BoundDeducer::Transform() {
232249
// We will ensure to set expr_ such that it contains target_
233250
if (const LT* op = expr_.as<LT>()) {
234251
if (GetPath(target_, op->a).empty()) {
235252
// a < b -> b >= a + 1
236-
is_greater_ = true;
253+
comp_op = kGreater;
237254
expr_ = op->b;
238255
result_ = op->a + 1;
239256
} else {
240257
// a < b -> a <= b - 1
241-
is_greater_ = false;
258+
comp_op = kLess;
242259
expr_ = op->a;
243260
result_ = op->b - 1;
244261
}
245262
} else if (const LE* op = expr_.as<LE>()) {
246263
if (GetPath(target_, op->a).empty()) {
247264
// a <= b -> b >= a
248-
is_greater_ = true;
265+
comp_op = kGreater;
249266
expr_ = op->b;
250267
result_ = op->a;
251268
} else {
252-
is_greater_ = false;
269+
comp_op = kLess;
253270
expr_ = op->a;
254271
result_ = op->b;
255272
}
256273
} else if (const GT* op = expr_.as<GT>()) {
257274
if (GetPath(target_, op->a).empty()) {
258275
// a > b -> b <= a - 1
259-
is_greater_ = false;
276+
comp_op = kLess;
260277
expr_ = op->b;
261278
result_ = op->a - 1;
262279
} else {
263280
// a > b -> a >= b + 1
264-
is_greater_ = true;
281+
comp_op = kGreater;
265282
expr_ = op->a;
266283
result_ = op->b + 1;
267284
}
268285
} else if (const GE* op = expr_.as<GE>()) {
269286
if (GetPath(target_, op->a).empty()) {
270287
// a >= b -> b <= a
271-
is_greater_ = false;
288+
comp_op = kLess;
289+
expr_ = op->b;
290+
result_ = op->a;
291+
} else {
292+
comp_op = kGreater;
293+
expr_ = op->a;
294+
result_ = op->b;
295+
}
296+
} else if (const EQ* op = expr_.as<EQ>()) {
297+
comp_op = kEqual;
298+
if (GetPath(target_, op->a).empty()) {
299+
// if the b == a -> a == b
272300
expr_ = op->b;
273301
result_ = op->a;
274302
} else {
275-
is_greater_ = true;
276303
expr_ = op->a;
277304
result_ = op->b;
278305
}
@@ -304,8 +331,16 @@ void BoundDeducer::Relax() {
304331
success_ = false;
305332
return;
306333
}
307-
expr_ = is_greater_ ? a.min() : a.max();
308-
result_ = is_greater_ ? b.max() : b.min();
334+
// Both LHS and RHS of the EQ should behave as constants e.g. i == j,
335+
// can not be resolved when either `i` or `j` or both are variables with
336+
// some Range OR `i` and `j` both should be a single point in IntSet
337+
if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max())
338+
|| !analyzer_.CanProve(a.min() == a.max()))) {
339+
success_ = false;
340+
return;
341+
}
342+
expr_ = (comp_op == kGreater) ? a.min() : a.max();
343+
result_ = (comp_op == kGreater) ? b.max() : b.min();
309344
}
310345

311346
IntSet DeduceBound(Expr v, Expr e,
@@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e,
315350
d.Deduce();
316351
if (!d.success_) return IntSet::nothing();
317352
Expr min = neg_inf(), max = pos_inf();
318-
if (d.is_greater_) {
353+
if (d.comp_op == kEqual) {
354+
min = d.result_;
355+
max = d.result_;
356+
} else if (d.comp_op == kGreater) {
319357
min = d.result_;
320358
} else {
321359
max = d.result_;

src/pass/loop_partition.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,6 @@ class PartitionFinder : public IRVisitor {
226226

227227
private:
228228
Expr InverseCond(const Expr& cond) {
229-
// We expect most condition not to be of EQ or NE form.
230-
// Currently we do not handle inversing EQ or NE.
231229
Expr inverse_cond;
232230
if (const LT* op = cond.as<LT>()) {
233231
// a < b -> a >= b
@@ -241,6 +239,12 @@ class PartitionFinder : public IRVisitor {
241239
} else if (const GE* op = cond.as<GE>()) {
242240
// a >= b -> a < b
243241
inverse_cond = LT::make(op->a, op->b);
242+
} else if (const EQ* op = cond.as<EQ>()) {
243+
// a == b -> a != b
244+
inverse_cond = NE::make(op->a, op->b);
245+
// a != b -> a == b
246+
} else if (const NE* op = cond.as<NE>()) {
247+
inverse_cond = EQ::make(op->a, op->b);
244248
}
245249
return inverse_cond;
246250
}

tests/python/unittest/test_arith_deduce_bound.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,44 @@ def test_deduce():
8585
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
8686
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
8787

88+
# tests for `EQ` op
89+
res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
90+
assert_expr_equal(res4.max_value, b)
91+
assert_expr_equal(res4.min_value, b)
92+
93+
# Unsatisfiable `EQ`, variable as one of the Operand
94+
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
95+
assert str(res5.max_value) == "neg_inf"
96+
assert str(res5.min_value) == "pos_inf"
97+
98+
# variable `a` on the RHS side
99+
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
100+
assert_expr_equal(res6.max_value, 10)
101+
assert_expr_equal(res6.min_value, 10)
102+
103+
# Add, Sub in `EQ`
104+
e4 = ((a - c) == (b + d))
105+
ans4 = (b + d + c)
106+
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
107+
assert_expr_equal(res7.max_value, ans4)
108+
assert_expr_equal(res7.min_value, ans4)
109+
110+
# Satisfiable Mul in `EQ` with negative sign
111+
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
112+
assert_expr_equal(res8.max_value, -2)
113+
assert_expr_equal(res8.min_value, -2)
114+
115+
# Unsatisfiable Mul in `EQ`
116+
e5 = (4 * a == b)
117+
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
118+
assert str(res9.max_value) == "neg_inf"
119+
assert str(res9.min_value) == "pos_inf"
120+
121+
# Unsatisfiable Mul in `EQ`
122+
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
123+
assert str(res10.max_value) == "neg_inf"
124+
assert str(res10.min_value) == "pos_inf"
125+
88126

89127
def test_check():
90128
a = tvm.var('a')
@@ -175,5 +213,6 @@ def test_complex(a1, a2, coff):
175213

176214
if __name__ == "__main__":
177215
test_check()
216+
test_deduce()
178217
test_deduce_basic()
179218
test_deduce_complex()

tests/python/unittest/test_pass_loop_partition.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ def test_condition():
171171
stmt = tvm.ir_pass.Simplify(stmt)
172172
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
173173

174+
def test_condition_EQ():
175+
ib = tvm.ir_builder.create()
176+
m = tvm.var('m')
177+
n = tvm.var('n')
178+
with ib.for_range(0, 10, 'i') as i:
179+
ib.emit(tvm.make.Evaluate(
180+
tvm.make.Select(ib.likely(tvm.expr.EQ(i, 5)), m, n)))
181+
stmt = ib.get()
182+
stmt = tvm.ir_pass.LoopPartition(stmt, True)
183+
stmt = tvm.ir_pass.Simplify(stmt)
184+
assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select))))
185+
174186
def test_thread_axis2():
175187
n = tvm.convert(4096)
176188
m = tvm.var('m')
@@ -420,6 +432,7 @@ def test_simple_rfactor():
420432
test_thread_axis()
421433
test_vectorize()
422434
test_condition()
435+
test_condition_EQ()
423436
test_thread_axis2()
424437
test_everything_during_deduction()
425438
test_single_likely()

0 commit comments

Comments
 (0)