@@ -38,32 +38,56 @@ def test_deduce():
3838 b_s = tvm .arith .intset_interval (2 , 3 )
3939 c_s = tvm .arith .intset_interval (10 , 15 )
4040 d_s = tvm .arith .intset_interval (- 3 , - 1 )
41+ zero = tvm .const (0 , "int32" )
4142
4243 e0 = (- b )* a + c - d
4344 res0 = tvm .arith .DeduceBound (a , e0 >= 0 , {b : b_s , c : c_s , d : d_s }, {})
4445 ans0 = ((d - c ) / (b * - 1 ))
4546 assert str (tvm .ir_pass .Simplify (res0 .max ())) == str (ans0 )
4647
48+ # expression containing variable a is on rhs
49+ res0 = tvm .arith .DeduceBound (a , zero <= e0 , {b : b_s , c : c_s , d : d_s }, {})
50+ assert str (tvm .ir_pass .Simplify (res0 .max ())) == str (ans0 )
51+
4752 e0 = d * a + c - d
4853 res0 = tvm .arith .DeduceBound (a , e0 >= 0 , {b : b_s , c : c_s , d : d_s }, {})
4954 ans0 = ((0 - c )/ d + 1 )
5055 assert str (tvm .ir_pass .Simplify (res0 .max ())) == str (ans0 )
5156
57+ # expression containing variable a is on rhs
58+ res0 = tvm .arith .DeduceBound (a , zero <= e0 , {b : b_s , c : c_s , d : d_s }, {})
59+ assert str (tvm .ir_pass .Simplify (res0 .max ())) == str (ans0 )
60+
5261 e1 = (a * 4 + b < c )
5362 res1 = tvm .arith .DeduceBound (a , e1 , {b : b_s , c : c_s , d : d_s }, {})
5463 ans1 = (((c - b ) + - 1 )/ 4 )
5564 assert str (tvm .ir_pass .Simplify (res1 .max ())) == str (ans1 )
5665
66+ # expression containing variable a is on rhs
67+ e1 = (c > a * 4 + b )
68+ res1 = tvm .arith .DeduceBound (a , e1 , {b : b_s , c : c_s , d : d_s }, {})
69+ assert str (tvm .ir_pass .Simplify (res1 .max ())) == str (ans1 )
70+
5771 e2 = (tvm .max (5 , a * 4 ) < 0 )
5872 res2 = tvm .arith .DeduceBound (a , e2 , {b : b_s , c : c_s , d : d_s }, {})
5973 assert str (res2 .max ()) == "neg_inf"
6074 assert str (res2 .min ()) == "pos_inf"
6175
76+ # expression containing variable a is on rhs
77+ e2 = (zero < tvm .max (5 , a * 4 ))
78+ res2 = tvm .arith .DeduceBound (a , e2 , {b : b_s , c : c_s , d : d_s }, {})
79+ assert str (res2 .max ()) == "neg_inf"
80+ assert str (res2 .min ()) == "pos_inf"
81+
82+
6283 e3 = (- b )+ a * c - d
6384 res3 = tvm .arith .DeduceBound (a , e3 >= 0 , {b : b_s , c : c_s , d : d_s }, {b : b_s , d : d_s })
6485 ans3 = 2 / c + 1
6586 assert str (tvm .ir_pass .Simplify (res3 .min ())) == str (ans3 )
6687
88+ res3 = tvm .arith .DeduceBound (a , zero <= e3 , {b : b_s , c : c_s , d : d_s }, {b : b_s , d : d_s })
89+ assert str (tvm .ir_pass .Simplify (res3 .min ())) == str (ans3 )
90+
6791def test_check ():
6892 a = tvm .var ('a' )
6993 b = tvm .var ('b' )
@@ -97,11 +121,13 @@ def test_basic(a1, a2, coff):
97121 [x , y ] = [res1 .max (), b_s .max ()] if coff > 0 else [res1 .min (), b_s .min ()]
98122 assert (tvm .ir_pass .Simplify ((x * coff + 3 + y ) < 17 )).value == 1
99123
100- res1 = tvm .arith .DeduceBound (a , e0 > 17 , {b : b_s }, {b : b_s })
124+ # expression containing variable a is on rhs
125+ res1 = tvm .arith .DeduceBound (a , tvm .const (17 , "int32" ) < e0 , {b : b_s }, {b : b_s })
101126 [x , y ] = [res1 .max (), b_s .max ()] if coff < 0 else [res1 .min (), b_s .min ()]
102127 assert (tvm .ir_pass .Simplify ((x * coff + 3 + y ) > 17 )).value == 1
103128
104- res1 = tvm .arith .DeduceBound (a , e0 <= 17 , {b : b_s }, {b : b_s })
129+ # expression containing variable a is on rhs
130+ res1 = tvm .arith .DeduceBound (a , tvm .const (17 , "int32" )>= e0 , {b : b_s }, {b : b_s })
105131 [x , y ] = [res1 .max (), b_s .max ()] if coff > 0 else [res1 .min (), b_s .min ()]
106132 assert (tvm .ir_pass .Simplify ((x * coff + 3 + y ) <= 17 )).value == 1
107133
@@ -127,15 +153,17 @@ def test_complex(a1, a2, coff):
127153 [t , x ] = [res1 .max (), b_s .max ()] if coff > 0 else [res1 .min (), b_s .min ()]
128154 assert (tvm .ir_pass .Simplify (((x * 3 + t * coff ) * 4 ) < 63 )).value == 1
129155
130- res1 = tvm .arith .DeduceBound (a , e0 <= 63 , {b : b_s }, {b : b_s })
156+ # expression containing variable a is on rhs
157+ res1 = tvm .arith .DeduceBound (a , tvm .const (63 , "int32" )>= e0 , {b : b_s }, {b : b_s })
131158 [t , x ] = [res1 .max (), b_s .max ()] if coff > 0 else [res1 .min (), b_s .min ()]
132159 assert (tvm .ir_pass .Simplify (((x * 3 + t * coff ) * 4 ) <= 63 )).value == 1
133160
134161 res1 = tvm .arith .DeduceBound (a , e0 > 63 , {b : b_s }, {b : b_s })
135162 [t , x ] = [res1 .max (), b_s .max ()] if coff < 0 else [res1 .min (), b_s .min ()]
136163 assert (tvm .ir_pass .Simplify (((x * 3 + t * coff ) * 4 ) > 63 )).value == 1
137164
138- res1 = tvm .arith .DeduceBound (a , e0 >= 63 , {b : b_s }, {b : b_s })
165+ # expression containing variable a is on rhs
166+ res1 = tvm .arith .DeduceBound (a , tvm .const (63 , "int32" ) <= e0 , {b : b_s }, {b : b_s })
139167 [t , x ] = [res1 .max (), b_s .max ()] if coff < 0 else [res1 .min (), b_s .min ()]
140168 assert (tvm .ir_pass .Simplify (((x * 3 + t * coff ) * 4 ) >= 63 )).value == 1
141169
0 commit comments