@@ -69,6 +69,33 @@ def test_bound3():
6969 assert (bounds [A1 .op .axis [0 ]].extent .value == 32 )
7070 assert (bounds [A1 .op .axis [1 ]].extent .value == 16 )
7171
72+ def test_bound_split_divisible ():
73+ m = tvm .var ('m' )
74+ l = tvm .var ('l' )
75+ A = tvm .placeholder ((8 * m , l ), name = 'A' )
76+ B = tvm .compute ((8 * m , l ), lambda i , j : A [i , j ], name = 'B' )
77+ s = tvm .create_schedule (B .op )
78+ xo , xi = s [B ].split (B .op .axis [0 ], 8 )
79+ bounds = tvm .schedule .InferBound (s )
80+ assert isinstance (bounds , tvm .container .Map )
81+ assert bounds [xo ].extent == m
82+ assert bounds [xi ].extent .value == 8
83+
84+ def test_bound_tile_divisible ():
85+ m = tvm .var ('m' )
86+ l = tvm .var ('l' )
87+ shape = (8 * m , 32 * l )
88+ A = tvm .placeholder (shape , name = 'A' )
89+ B = tvm .compute (shape , lambda i , j : A [i , j ], name = 'B' )
90+ s = tvm .create_schedule (B .op )
91+ xo , yo , xi , yi = s [B ].tile (B .op .axis [0 ], B .op .axis [1 ], 8 , 32 )
92+ bounds = tvm .schedule .InferBound (s )
93+ assert isinstance (bounds , tvm .container .Map )
94+ assert bounds [xo ].extent == m
95+ assert bounds [xi ].extent .value == 8
96+ assert bounds [yo ].extent == l
97+ assert bounds [yi ].extent .value == 32
98+
7299def test_bound_fusesplit1 ():
73100 m = tvm .var ('m' )
74101 l = tvm .var ('l' )
@@ -393,3 +420,5 @@ def _check(B, A=A):
393420 test_bound_simplification_failure ()
394421 test_bound_fusesplit1 ()
395422 test_bound_fusesplit2 ()
423+ test_bound_split_divisible ()
424+ test_bound_tile_divisible ()
0 commit comments