@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
11391139 return %1 : tensor <12 x4 xf32 >
11401140}
11411141// CHECK-LABEL: @fold_collapse_of_expand
1142- // CHECK-NOT: linalg .{{.*}}shape
1142+ // CHECK-NOT: tensor .{{.*}}_shape
11431143
11441144// -----
11451145
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
11521152 return %1 : tensor <?x?xf32 >
11531153}
11541154// CHECK-LABEL: @fold_collapse_of_expand_dynamic
1155- // CHECK-NOT: linalg.{{.*}}_shape
1155+ // CHECK-NOT: tensor.{{.*}}_shape
1156+
1157+ // -----
1158+
1159+ func.func @fold_collapse_of_expand_fully_dynamic (%arg0 : tensor <?x?xf32 >, %arg1: index , %arg2: index , %arg3: index )
1160+ -> tensor <?x?xf32 > {
1161+ %0 = tensor.expand_shape %arg0 [[0 , 1 ], [2 ]] output_shape [%arg1 , %arg2 , %arg3 ]
1162+ : tensor <?x?xf32 > into tensor <?x?x?xf32 >
1163+ %1 = tensor.collapse_shape %0 [[0 , 1 ], [2 ]]
1164+ : tensor <?x?x?xf32 > into tensor <?x?xf32 >
1165+ return %1 : tensor <?x?xf32 >
1166+ }
1167+ // CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
1168+ // CHECK-NOT: tensor.{{.*}}_shape
1169+
1170+ // -----
1171+
1172+ func.func @no_fold_parallel_collapse_of_expand_dynamic (%arg0 : tensor <?x?x?xf32 >, %arg1: index , %arg2: index , %arg3: index , %arg4: index )
1173+ -> tensor <?x?x?xf32 > {
1174+ %0 = tensor.expand_shape %arg0 [[0 , 1 ], [2 ], [3 ]] output_shape [%arg1 , %arg2 , %arg3 , %arg4 ]
1175+ : tensor <?x?x?xf32 > into tensor <?x?x?x?xf32 >
1176+ %1 = tensor.collapse_shape %0 [[0 ], [1 ], [2 , 3 ]]
1177+ : tensor <?x?x?x?xf32 > into tensor <?x?x?xf32 >
1178+ return %1 : tensor <?x?x?xf32 >
1179+ }
1180+ // CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
1181+ // CHECK: tensor.expand_shape
1182+ // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1183+ // CHECK: return %[[COLLAPSE]]
1184+
1185+ // -----
1186+
1187+ func.func @fold_expand_of_collapse (%arg0 : tensor <3 x4 x4 xf32 >) -> tensor <3 x4 x4 xf32 > {
1188+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1189+ : tensor <3 x4 x4 xf32 > into tensor <12 x4 xf32 >
1190+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [3 , 4 , 4 ]
1191+ : tensor <12 x4 xf32 > into tensor <3 x4 x4 xf32 >
1192+ return %1 : tensor <3 x4 x4 xf32 >
1193+ }
1194+ // CHECK-LABEL: @fold_expand_of_collapse
1195+ // CHECK-NOT: tensor.{{.*}}_shape
1196+
1197+ // -----
1198+
1199+ func.func @fold_expand_of_collapse_dynamic (%arg0 : tensor <?x4 x?xf32 >, %arg1: index , %arg2: index )
1200+ -> tensor <?x4 x?xf32 > {
1201+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1202+ : tensor <?x4 x?xf32 > into tensor <?x?xf32 >
1203+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [%arg1 , 4 , %arg2 ]
1204+ : tensor <?x?xf32 > into tensor <?x4 x?xf32 >
1205+ return %1 : tensor <?x4 x?xf32 >
1206+ }
1207+ // CHECK-LABEL: @fold_expand_of_collapse_dynamic
1208+ // CHECK-NOT: tensor.{{.*}}_shape
1209+
1210+ // -----
1211+
1212+ func.func @no_fold_expand_of_collapse_dynamic (%arg0 : tensor <?x?x?xf32 >, %arg1: index , %arg2: index , %arg3: index )
1213+ -> tensor <?x?x?xf32 > {
1214+ %0 = tensor.collapse_shape %arg0 [[0 , 1 ], [2 ]]
1215+ : tensor <?x?x?xf32 > into tensor <?x?xf32 >
1216+ %1 = tensor.expand_shape %0 [[0 , 1 ], [2 ]] output_shape [%arg1 , %arg2 , %arg3 ]
1217+ : tensor <?x?xf32 > into tensor <?x?x?xf32 >
1218+ return %1 : tensor <?x?x?xf32 >
1219+ }
1220+ // CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
1221+ // CHECK: tensor.collapse_shape
1222+ // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
1223+ // CHECK: return %[[EXPAND]]
11561224
11571225// -----
11581226
0 commit comments