@@ -1066,74 +1066,5 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="
10661066 tvm .ir .assert_structural_equal (mod , Expected )
10671067
10681068
1069- def test_cumsum ():
1070- # fmt: off
1071- @I .ir_module
1072- class Cumsum :
1073- @R .function
1074- def main (x : R .Tensor ((3 , 2 , 3 ), "float32" )):
1075- gv = R .cumsum (x , axis = 1 , dtype = "int32" )
1076- return gv
1077-
1078- @I .ir_module
1079- class Expected :
1080- @T .prim_func (private = True )
1081- def cumsum (var_rxplaceholder : T .handle , out_buf : T .Buffer ((T .int64 (3 ), T .int64 (2 ), T .int64 (3 )), "int32" )):
1082- T .func_attr ({"tir.noalias" : True })
1083- rxplaceholder = T .match_buffer (var_rxplaceholder , (T .int64 (3 ), T .int64 (2 ), T .int64 (3 )), offset_factor = 1 )
1084- with T .block ("cumsum_generic" ):
1085- for fused in T .parallel (T .int64 (9 )):
1086- out_buf [(fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) // T .int64 (3 ) // T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) // T .int64 (3 ) % T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) % T .int64 (3 )] = T .Cast ("int32" , rxplaceholder [(fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) // T .int64 (3 ) // T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) // T .int64 (3 ) % T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 )) % T .int64 (3 )])
1087- for _k in range (T .int64 (1 )):
1088- out_buf [(fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) // T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) % T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) % T .int64 (3 )] = out_buf [(fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 ) - T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) // T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 ) - T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) % T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 ) - T .int64 (1 )) * T .int64 (3 )) % T .int64 (3 )] + T .Cast ("int32" , rxplaceholder [(fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) // T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) // T .int64 (3 ) % T .int64 (2 ), (fused // T .int64 (3 ) * T .int64 (2 ) * T .int64 (3 ) + fused % T .int64 (3 ) + (_k + T .int64 (1 )) * T .int64 (3 )) % T .int64 (3 )])
1089-
1090- @R .function
1091- def main (x : R .Tensor ((3 , 2 , 3 ), dtype = "float32" )) -> R .Tensor ((3 , 2 , 3 ), dtype = "int32" ):
1092- cls = Expected
1093- gv = R .call_tir (cls .cumsum , (x ,), out_sinfo = R .Tensor ((3 , 2 , 3 ), dtype = "int32" ))
1094- return gv
1095- # fmt: on
1096-
1097- mod = LegalizeOps ()(Cumsum )
1098- tvm .ir .assert_structural_equal (mod , Expected )
1099-
1100-
1101- def test_cumsum_symbolic ():
1102- # fmt: off
1103- @I .ir_module
1104- class Cumsum :
1105- @R .function
1106- def main (x : R .Tensor (("a" , "b" , "c" ), "float32" )):
1107- gv = R .cumsum (x , axis = 1 , dtype = "int32" )
1108- return gv
1109-
1110- @I .ir_module
1111- class Expected :
1112- @T .prim_func (private = True )
1113- def cumsum (var_rxplaceholder : T .handle , var_cumsum_generic : T .handle ):
1114- T .func_attr ({"tir.noalias" : True })
1115- a , b , c = T .int64 (), T .int64 (), T .int64 ()
1116- rxplaceholder = T .match_buffer (var_rxplaceholder , (a , b , c ), offset_factor = 1 )
1117- out_buf = T .match_buffer (var_cumsum_generic , (a , b , c ), "int32" )
1118- with T .block ("cumsum_generic" ):
1119- for fused in T .parallel (a * c ):
1120- out_buf [(fused // c * b * c + fused % c ) // c // b , (fused // c * b * c + fused % c ) // c % b , (fused // c * b * c + fused % c ) % c ] = T .Cast ("int32" , rxplaceholder [(fused // c * b * c + fused % c ) // c // b , (fused // c * b * c + fused % c ) // c % b , (fused // c * b * c + fused % c ) % c ])
1121- for _k in range (b - T .int64 (1 )):
1122- out_buf [(fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) // c // b , (fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) // c % b , (fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) % c ] = out_buf [(fused // c * b * c + fused % c + (_k + T .int64 (1 ) - T .int64 (1 )) * c ) // c // b , (fused // c * b * c + fused % c + (_k + T .int64 (1 ) - T .int64 (1 )) * c ) // c % b , (fused // c * b * c + fused % c + (_k + T .int64 (1 ) - T .int64 (1 )) * c ) % c ] + T .Cast ("int32" , rxplaceholder [(fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) // c // b , (fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) // c % b , (fused // c * b * c + fused % c + (_k + T .int64 (1 )) * c ) % c ])
1123-
1124- @R .function
1125- def main (x : R .Tensor (("a" , "b" , "c" ), dtype = "float32" )) -> R .Tensor (("a" , "b" , "c" ), dtype = "int32" ):
1126- a = T .int64 ()
1127- b = T .int64 ()
1128- c = T .int64 ()
1129- cls = Expected
1130- gv = R .call_tir (cls .cumsum , (x ,), out_sinfo = R .Tensor ((a , b , c ), dtype = "int32" ))
1131- return gv
1132- # fmt: on
1133-
1134- mod = LegalizeOps ()(Cumsum )
1135- tvm .ir .assert_structural_equal (mod , Expected )
1136-
1137-
11381069if __name__ == "__main__" :
11391070 tvm .testing .main ()
0 commit comments