@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
9494 def assert_simplified_equal (index_simplified , index_direct ):
9595 assert tvm .ir_pass .Equal (index_simplified , index_direct ),\
9696 "index_simplified=%s, index_direct=%s" % (index_simplified , index_direct )
97- idxdiv = tvm .indexdiv
98- idxmod = tvm .indexmod
97+ idxd = tvm .indexdiv
98+ idxm = tvm .indexmod
9999 # Test Case1
100100 index_simplified = A_stride .vload (
101- (idxdiv ( idxmod (k0 , k1 ), s ), idxmod ( idxmod (k0 , k1 ), s ) + idxdiv (k0 , k1 ) * k1 ))
101+ (idxd ( idxm (k0 , k1 ), s ), idxm ( idxm (k0 , k1 ), s ) + idxd (k0 , k1 ) * k1 ))
102102 index_direct = A_stride .vload ((0 , k0 ))
103103 assert_simplified_equal (index_simplified , index_direct )
104104
105105 # Test Case2
106- index_simplified = A .vload ((idxdiv ( idxmod (k0 , idxdiv (k1 , s )), n ),
107- idxmod ( idxmod (k0 , idxdiv (k1 , s )), n ) + idxmod (k0 , k1 )))
108- index_direct = A .vload ((0 , idxmod (k0 , k1 ) + idxmod (k0 , idxdiv (k1 , s ))))
106+ index_simplified = A .vload ((idxd ( idxm (k0 , idxd (k1 , s )), n ),
107+ idxm ( idxm (k0 , idxd (k1 , s )), n ) + idxm (k0 , k1 )))
108+ index_direct = A .vload ((0 , idxm (k0 , k1 ) + idxm (k0 , idxd (k1 , s ))))
109109 assert_simplified_equal (index_simplified , index_direct )
110110 # Test Case3
111- index_simplified = A .vload ((idxdiv (( idxdiv (k0 , idxdiv (k1 , s )) * idxdiv (k1 , s )), n ) +
112- idxdiv ( idxmod (k0 , idxdiv (k1 , s )), n ),
113- idxmod (( idxdiv (k0 , idxdiv (k1 , s )) * idxdiv (k1 , s )), n ) +
114- idxmod ( idxmod (k0 , idxdiv (k1 , s )), n )))
111+ index_simplified = A .vload ((idxd (( idxd (k0 , idxd (k1 , s )) * idxd (k1 , s )), n ) +
112+ idxd ( idxm (k0 , idxd (k1 , s )), n ),
113+ idxm (( idxd (k0 , idxd (k1 , s )) * idxd (k1 , s )), n ) +
114+ idxm ( idxm (k0 , idxd (k1 , s )), n )))
115115 index_direct = A .vload ((0 , k0 ))
116116 assert_simplified_equal (index_simplified , index_direct )
117117 # Test Case4 (not able to simplify)
118- index_simplified = A .vload ((idxdiv ( idxmod (k0 , idxdiv (k1 , s )), n ),
119- idxmod ( idxmod (k0 , idxdiv (k1 , n )), n ) + idxmod (k0 , k1 )))
120- index_direct = A .vload ((0 , idxdiv ( idxmod (k0 , idxdiv (k1 , s )), n ) * n +
121- (idxmod ( idxmod (k0 , idxdiv (k1 , n )), n ) + idxmod (k0 , k1 ))))
118+ index_simplified = A .vload ((idxd ( idxm (k0 , idxd (k1 , s )), n ),
119+ idxm ( idxm (k0 , idxd (k1 , n )), n ) + idxm (k0 , k1 )))
120+ index_direct = A .vload ((0 , idxd ( idxm (k0 , idxd (k1 , s )), n ) * n +
121+ (idxm ( idxm (k0 , idxd (k1 , n )), n ) + idxm (k0 , k1 ))))
122122 assert_simplified_equal (index_simplified , index_direct )
123123
124124
0 commit comments