@@ -69,6 +69,8 @@ std::vector<const Node*> GetPath(Expr target, Expr expr) {
6969 return v.path_ ;
7070}
7171
72+ enum CompareOp {kGreater , kLess , kEqual };
73+
7274// a visitor to deduce the bound of a variable from a expression
7375class BoundDeducer : public IRVisitor {
7476 public:
@@ -120,7 +122,7 @@ class BoundDeducer: public IRVisitor {
120122 } else {
121123 result_ -= op->a ;
122124 result_ = - result_;
123- is_greater_ = !is_greater_ ;
125+ comp_op = ReverseOp (comp_op) ;
124126 }
125127 Visit (left ? op->a : op->b );
126128 }
@@ -138,7 +140,7 @@ class BoundDeducer: public IRVisitor {
138140 }
139141
140142 if (sign_operand == SignType::kNegative ) {
141- is_greater_ = !is_greater_ ;
143+ comp_op = ReverseOp (comp_op) ;
142144 } else if (sign_operand == SignType::kUnknown ) {
143145 // unable to get the sign of operand
144146 success_ = false ;
@@ -154,8 +156,12 @@ class BoundDeducer: public IRVisitor {
154156 // NOTE: this accounts for truc div behavior.
155157 bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative ();
156158
157- if (is_greater_ ) {
159+ if (comp_op == kGreater ) {
158160 result_ += 1 ;
161+ } else if (comp_op == kEqual ) {
162+ // condition unsatisfiable as with truc div, it will change the expression
163+ success_ = false ;
164+ return ;
159165 } else {
160166 // NOTE: this is a bit sutble hack.
161167 //
@@ -185,14 +191,14 @@ class BoundDeducer: public IRVisitor {
185191 }
186192
187193 Expr result_;
188- bool is_greater_{ true };
194+ CompareOp comp_op{ kGreater };
189195 bool success_{true };
190196
191197 private:
192198 void Init ();
193199 void Transform ();
194200 void Relax ();
195-
201+ CompareOp ReverseOp (CompareOp comp_op);
196202 Expr target_;
197203 Expr expr_;
198204 const std::unordered_map<const Variable*, IntSet>& hint_map_;
@@ -228,51 +234,71 @@ void BoundDeducer::Init() {
228234 Transform ();
229235}
230236
237+ CompareOp BoundDeducer::ReverseOp (CompareOp comp_op) {
238+ switch (comp_op) {
239+ case kEqual : return kEqual ; // IntSet can not represent range for `NE
240+ case kGreater : return kLess ;
241+ case kLess : return kGreater ;
242+ default :
243+ return kGreater ;
244+ }
245+ }
246+
231247void BoundDeducer::Transform () {
232248 // We will ensure to set expr_ such that it contains target_
233249 if (const LT* op = expr_.as <LT>()) {
234250 if (GetPath (target_, op->a ).empty ()) {
235251 // a < b -> b >= a + 1
236- is_greater_ = true ;
252+ comp_op = kGreater ;
237253 expr_ = op->b ;
238254 result_ = op->a + 1 ;
239255 } else {
240256 // a < b -> a <= b - 1
241- is_greater_ = false ;
257+ comp_op = kLess ;
242258 expr_ = op->a ;
243259 result_ = op->b - 1 ;
244260 }
245261 } else if (const LE* op = expr_.as <LE>()) {
246262 if (GetPath (target_, op->a ).empty ()) {
247263 // a <= b -> b >= a
248- is_greater_ = true ;
264+ comp_op = kGreater ;
249265 expr_ = op->b ;
250266 result_ = op->a ;
251267 } else {
252- is_greater_ = false ;
268+ comp_op = kLess ;
253269 expr_ = op->a ;
254270 result_ = op->b ;
255271 }
256272 } else if (const GT* op = expr_.as <GT>()) {
257273 if (GetPath (target_, op->a ).empty ()) {
258274 // a > b -> b <= a - 1
259- is_greater_ = false ;
275+ comp_op = kLess ;
260276 expr_ = op->b ;
261277 result_ = op->a - 1 ;
262278 } else {
263279 // a > b -> a >= b + 1
264- is_greater_ = true ;
280+ comp_op = kGreater ;
265281 expr_ = op->a ;
266282 result_ = op->b + 1 ;
267283 }
268284 } else if (const GE* op = expr_.as <GE>()) {
269285 if (GetPath (target_, op->a ).empty ()) {
270286 // a >= b -> b <= a
271- is_greater_ = false ;
287+ comp_op = kLess ;
288+ expr_ = op->b ;
289+ result_ = op->a ;
290+ } else {
291+ comp_op = kGreater ;
292+ expr_ = op->a ;
293+ result_ = op->b ;
294+ }
295+ } else if (const EQ* op = expr_.as <EQ>()) {
296+ comp_op = kEqual ;
297+ if (GetPath (target_, op->a ).empty ()) {
298+ // if the b == a -> a == b
272299 expr_ = op->b ;
273300 result_ = op->a ;
274301 } else {
275- is_greater_ = true ;
276302 expr_ = op->a ;
277303 result_ = op->b ;
278304 }
@@ -304,8 +330,16 @@ void BoundDeducer::Relax() {
304330 success_ = false ;
305331 return ;
306332 }
307- expr_ = is_greater_ ? a.min () : a.max ();
308- result_ = is_greater_ ? b.max () : b.min ();
333+ // Both LHS and RHS of the EQ should behave as constants e.g. i == j,
334+ // can not be resolved when either `i` or `j` or both are variables with
335+ // some Range OR `i` and `j` both should be a single point in IntSet
336+ if (comp_op == kEqual && (!analyzer_.CanProve (b.min () == b.max ())
337+ || !analyzer_.CanProve (a.min () == a.max ()))) {
338+ success_ = false ;
339+ return ;
340+ }
341+ expr_ = (comp_op == kGreater ) ? a.min () : a.max ();
342+ result_ = (comp_op == kGreater ) ? b.max () : b.min ();
309343}
310344
311345IntSet DeduceBound (Expr v, Expr e,
@@ -315,7 +349,10 @@ IntSet DeduceBound(Expr v, Expr e,
315349 d.Deduce ();
316350 if (!d.success_ ) return IntSet::nothing ();
317351 Expr min = neg_inf (), max = pos_inf ();
318- if (d.is_greater_ ) {
352+ if (d.comp_op == kEqual ) {
353+ min = d.result_ ;
354+ max = d.result_ ;
355+ } else if (d.comp_op == kGreater ) {
319356 min = d.result_ ;
320357 } else {
321358 max = d.result_ ;
0 commit comments