Skip to content

Commit 1db4c6b

Browse files
committed
Fix cpu pack-unpack-mmt4d shape
Signed-off-by: hanhanW <[email protected]>
1 parent f6f89c0 commit 1db4c6b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso
8181
func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
8282
%zero = arith.constant 0 : i32
8383

84-
%A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
84+
%A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
8585
%B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
86-
%C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
86+
%C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32>
8787

8888
// Pack matrices
89-
%A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
89+
%A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
9090
%B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
91-
%C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
91+
%C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32>
9292

9393
// MMT4D
94-
%mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>
94+
%mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32>
9595

9696
// Unpack output
9797
%C_out_empty = tensor.empty() : tensor<7x13xi32>
98-
%C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32>
98+
%C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32>
9999

100100
return %C_out_unpack : tensor<7x13xi32>
101101
}

0 commit comments

Comments
 (0)