@@ -110,31 +110,64 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
110110
111111func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices (
112112 %arg : memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>,
113- %idx0 : index ,
114- %idx1 : index ) -> vector <2 x2 xf32 > {
113+ %idx_1 : index ,
114+ %idx_2 : index ) -> vector <2 x2 xf32 > {
115115
116116 %c0 = arith.constant 0 : index
117117 %cst_1 = arith.constant 0.000000e+00 : f32
118- %8 = vector.transfer_read %arg [%c0 , %idx0 , %idx1 , %c0 ], %cst_1 {in_bounds = [true , true ]} :
118+ %8 = vector.transfer_read %arg [%c0 , %idx_1 , %idx_2 , %c0 ], %cst_1 {in_bounds = [true , true ]} :
119119 memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>, vector <2 x2 xf32 >
120120 return %8 : vector <2 x2 xf32 >
121121}
122122
123123// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
124124
125125// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
126- // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
126+ // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
127+ // CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
127128// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
128129
129130// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
130131// CHECK-128B: memref.collapse_shape
131132
132133// -----
133134
134- // The input memref has a dynamic trailing shape and hence is not flattened.
135- // TODO: This case could be supported via memref.dim
135+ // The leading dynamic shapes don't affect whether this example is flattenable
136+ // or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
136137
137- func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes (
138+ func.func @transfer_read_leading_dynamic_dims (
139+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
140+ %idx_1 : index ,
141+ %idx_2 : index ) -> vector <8 x4 xi8 > {
142+
143+ %c0_i8 = arith.constant 0 : i8
144+ %c0 = arith.constant 0 : index
145+ %result = vector.transfer_read %arg [%idx_1 , %idx_2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} :
146+ memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
147+ return %result : vector <8 x4 xi8 >
148+ }
149+
150+ // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
151+ // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
152+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
153+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
154+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
155+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
156+ // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
157+ // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
158+ // CHECK-SAME: {in_bounds = [true]}
159+ // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
160+ // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
161+ // CHECK: return %[[VEC2D]] : vector<8x4xi8>
162+
163+ // CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
164+ // CHECK-128B: memref.collapse_shape
165+
166+ // -----
167+
168+ // One of the dims to be flattened is dynamic - not supported ATM.
169+
170+ func.func @negative_transfer_read_dynamic_dim_to_flatten (
138171 %idx_1: index ,
139172 %idx_2: index ,
140173 %m_in: memref <1 x?x4 x6 xi32 >) -> vector <1 x2 x6 xi32 > {
@@ -146,11 +179,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
146179 return %v : vector <1 x2 x6 xi32 >
147180}
148181
149- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
182+ // CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
150183// CHECK-NOT: memref.collapse_shape
151184// CHECK-NOT: vector.shape_cast
152185
153- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
186+ // CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
154187// CHECK-128B-NOT: memref.collapse_shape
155188
156189// -----
@@ -326,11 +359,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
326359func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices (
327360 %value : vector <2 x2 xf32 >,
328361 %subview : memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>,
329- %idx0 : index ,
330- %idx1 : index ) {
362+ %idx_1 : index ,
363+ %idx_2 : index ) {
331364
332365 %c0 = arith.constant 0 : index
333- vector.transfer_write %value , %subview [%c0 , %idx0 , %idx1 , %c0 ] {in_bounds = [true , true ]} : vector <2 x2 xf32 >, memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>
366+ vector.transfer_write %value , %subview [%c0 , %idx_1 , %idx_2 , %c0 ] {in_bounds = [true , true ]} : vector <2 x2 xf32 >, memref <1 x3 x3 x2 xf32 , strided <[40 , 10 , 2 , 1 ], offset : ?>>
334367 return
335368}
336369
@@ -345,10 +378,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
345378
346379// -----
347380
348- // The input memref has a dynamic trailing shape and hence is not flattened.
349- // TODO: This case could be supported via memref.dim
381+ // The leading dynamic shapes don't affect whether this example is flattenable
382+ // or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
383+
384+ func.func @transfer_write_leading_dynamic_dims (
385+ %vec : vector <8 x4 xi8 >,
386+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
387+ %idx_1 : index ,
388+ %idx_2 : index ) {
389+
390+ %c0 = arith.constant 0 : index
391+ vector.transfer_write %vec , %arg [%idx_1 , %idx_2 , %c0 , %c0 ] {in_bounds = [true , true ]} :
392+ vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
393+ return
394+ }
395+
396+ // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
397+ // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
398+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
399+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
400+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
401+ // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
402+ // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
403+ // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
404+ // CHECK-SAME: {in_bounds = [true]}
405+ // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
406+
407+ // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
408+ // CHECK-128B: memref.collapse_shape
350409
351- func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
410+ // -----
411+
412+ // One of the dims to be flattened is dynamic - not supported ATM.
413+
414+ func.func @negative_transfer_write_dynamic_to_flatten (
352415 %idx_1: index ,
353416 %idx_2: index ,
354417 %vec : vector <1 x2 x6 xi32 >,
@@ -361,11 +424,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
361424 return
362425}
363426
364- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
427+ // CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
365428// CHECK-NOT: memref.collapse_shape
366429// CHECK-NOT: vector.shape_cast
367430
368- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
431+ // CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
369432// CHECK-128B-NOT: memref.collapse_shape
370433
371434// -----
@@ -434,56 +497,10 @@ func.func @transfer_write_non_contiguous_src(
434497// -----
435498
436499///----------------------------------------------------------------------------------------
437- /// TODO: Categorize + re-format
500+ /// [Pattern: DropUnitDimFromElementwiseOps]
501+ /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
438502///----------------------------------------------------------------------------------------
439503
440- func.func @transfer_read_flattenable_with_dynamic_dims_and_indices (%arg0 : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) -> vector <8 x4 xi8 > {
441- %c0_i8 = arith.constant 0 : i8
442- %c0 = arith.constant 0 : index
443- %result = vector.transfer_read %arg0 [%arg1 , %arg2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
444- return %result : vector <8 x4 xi8 >
445- }
446-
447- // CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
448- // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
449- // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
450- // CHECK: %[[C0:.+]] = arith.constant 0 : index
451- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
452- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
453- // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
454- // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
455- // CHECK-SAME: {in_bounds = [true]}
456- // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
457- // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
458- // CHECK: return %[[VEC2D]] : vector<8x4xi8>
459-
460- // CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
461- // CHECK-128B: memref.collapse_shape
462-
463- // -----
464-
465- func.func @transfer_write_flattenable_with_dynamic_dims_and_indices (%vec : vector <8 x4 xi8 >, %dst : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) {
466- %c0 = arith.constant 0 : index
467- vector.transfer_write %vec , %dst [%arg1 , %arg2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
468- return
469- }
470-
471- // CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
472- // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
473- // CHECK: %[[C0:.+]] = arith.constant 0 : index
474- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
475- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
476- // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
477- // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
478- // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
479- // CHECK-SAME: {in_bounds = [true]}
480- // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
481-
482- // CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
483- // CHECK-128B: memref.collapse_shape
484-
485- // -----
486-
487504func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
488505 %add = arith.addi %arg0 , %arg0 : vector <1 x8 xi32 >
489506 return %add : vector <1 x8 xi32 >
0 commit comments