@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
6969 return v.path_ ;
7070}
7171
72+ enum CompareOp {kGreater , kLess , kEqual };
73+
7274// a visitor to deduce the bound of a variable from a expression
7375class BoundDeducer : public IRVisitor {
7476 public:
@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
120122 } else {
121123 result_ -= op->a ;
122124 result_ = - result_;
123- is_greater_ = !is_greater_ ;
125+ comp_op = ReverseOp (comp_op) ;
124126 }
125127 Visit (left ? op->a : op->b );
126128 }
@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
138140 }
139141
140142 if (sign_operand == SignType::kNegative ) {
141- is_greater_ = !is_greater_ ;
143+ comp_op = ReverseOp (comp_op) ;
142144 } else if (sign_operand == SignType::kUnknown ) {
143145 // unable to get the sign of operand
144146 success_ = false ;
@@ -151,11 +153,15 @@ class BoundDeducer: public IRVisitor {
151153
152154 if (!divided) {
153155 // Handle non-divisible case
154- // NOTE: this accounts for truc div behavior.
156+ // NOTE: this accounts for trunc div behavior.
155157 bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative ();
156158
157- if (is_greater_ ) {
159+ if (comp_op == kGreater ) {
158160 result_ += 1 ;
161+ } else if (comp_op == kEqual ) {
162+ // condition unsatisfiable as with trunc div, it will change the expression
163+ success_ = false ;
164+ return ;
159165 } else {
160166 // NOTE: this is a bit sutble hack.
161167 //
@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
185191 }
186192
187193 Expr result_;
188- bool is_greater_{ true };
194+ CompareOp comp_op{ kGreater };
189195 bool success_{true };
190196
191197 private:
192198 void Init ();
193199 void Transform ();
194200 void Relax ();
195-
201+ CompareOp ReverseOp (CompareOp comp_op);
196202 Expr target_;
197203 Expr expr_;
198204 const std::unordered_map<const Variable*, IntSet>& hint_map_;
@@ -228,51 +234,72 @@ void BoundDeducer::Init() {
228234 Transform ();
229235}
230236
237+ CompareOp BoundDeducer::ReverseOp (CompareOp comp_op) {
238+ switch (comp_op) {
239+ case kEqual : return kEqual ; // IntSet can not represent range for `NE
240+ case kGreater : return kLess ;
241+ case kLess : return kGreater ;
242+ default :
243+ LOG (FATAL) << " Not a valid compare op" ;
244+ return kGreater ; // return some default value
245+ }
246+ }
247+
231248void BoundDeducer::Transform () {
232249 // We will ensure to set expr_ such that it contains target_
233250 if (const LT* op = expr_.as <LT>()) {
234251 if (GetPath (target_, op->a ).empty ()) {
235252 // a < b -> b >= a + 1
236- is_greater_ = true ;
253+ comp_op = kGreater ;
237254 expr_ = op->b ;
238255 result_ = op->a + 1 ;
239256 } else {
240257 // a < b -> a <= b - 1
241- is_greater_ = false ;
258+ comp_op = kLess ;
242259 expr_ = op->a ;
243260 result_ = op->b - 1 ;
244261 }
245262 } else if (const LE* op = expr_.as <LE>()) {
246263 if (GetPath (target_, op->a ).empty ()) {
247264 // a <= b -> b >= a
248- is_greater_ = true ;
265+ comp_op = kGreater ;
249266 expr_ = op->b ;
250267 result_ = op->a ;
251268 } else {
252- is_greater_ = false ;
269+ comp_op = kLess ;
253270 expr_ = op->a ;
254271 result_ = op->b ;
255272 }
256273 } else if (const GT* op = expr_.as <GT>()) {
257274 if (GetPath (target_, op->a ).empty ()) {
258275 // a > b -> b <= a - 1
259- is_greater_ = false ;
276+ comp_op = kLess ;
260277 expr_ = op->b ;
261278 result_ = op->a - 1 ;
262279 } else {
263280 // a > b -> a >= b + 1
264- is_greater_ = true ;
281+ comp_op = kGreater ;
265282 expr_ = op->a ;
266283 result_ = op->b + 1 ;
267284 }
268285 } else if (const GE* op = expr_.as <GE>()) {
269286 if (GetPath (target_, op->a ).empty ()) {
270287 // a >= b -> b <= a
271- is_greater_ = false ;
288+ comp_op = kLess ;
289+ expr_ = op->b ;
290+ result_ = op->a ;
291+ } else {
292+ comp_op = kGreater ;
293+ expr_ = op->a ;
294+ result_ = op->b ;
295+ }
296+ } else if (const EQ* op = expr_.as <EQ>()) {
297+ comp_op = kEqual ;
298+ if (GetPath (target_, op->a ).empty ()) {
299+ // if the b == a -> a == b
272300 expr_ = op->b ;
273301 result_ = op->a ;
274302 } else {
275- is_greater_ = true ;
276303 expr_ = op->a ;
277304 result_ = op->b ;
278305 }
@@ -304,8 +331,16 @@ void BoundDeducer::Relax() {
304331 success_ = false ;
305332 return ;
306333 }
307- expr_ = is_greater_ ? a.min () : a.max ();
308- result_ = is_greater_ ? b.max () : b.min ();
334+ // Both LHS and RHS of the EQ should behave as constants e.g. i == j,
335+ // can not be resolved when either `i` or `j` or both are variables with
336+ // some Range OR `i` and `j` both should be a single point in IntSet
337+ if (comp_op == kEqual && (!analyzer_.CanProve (b.min () == b.max ())
338+ || !analyzer_.CanProve (a.min () == a.max ()))) {
339+ success_ = false ;
340+ return ;
341+ }
342+ expr_ = (comp_op == kGreater ) ? a.min () : a.max ();
343+ result_ = (comp_op == kGreater ) ? b.max () : b.min ();
309344}
310345
311346IntSet DeduceBound (Expr v, Expr e,
@@ -315,7 +350,10 @@ IntSet DeduceBound(Expr v, Expr e,
315350 d.Deduce ();
316351 if (!d.success_ ) return IntSet::nothing ();
317352 Expr min = neg_inf (), max = pos_inf ();
318- if (d.is_greater_ ) {
353+ if (d.comp_op == kEqual ) {
354+ min = d.result_ ;
355+ max = d.result_ ;
356+ } else if (d.comp_op == kGreater ) {
319357 min = d.result_ ;
320358 } else {
321359 max = d.result_ ;
0 commit comments