@@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso
8181func.func private @mmt4d (%A: tensor <7 x16 xi32 >, %B: tensor <16 x13 xi32 >, %C: tensor <7 x13 xi32 >) -> tensor <7 x13 xi32 > {
8282 %zero = arith.constant 0 : i32
8383
84- %A_pack_empty = tensor.empty () : tensor <2 x 16 x 8 x 1 x i32 >
84+ %A_pack_empty = tensor.empty () : tensor <1 x 16 x 8 x 1 x i32 >
8585 %B_pack_empty = tensor.empty () : tensor <2 x16 x8 x1 xi32 >
86- %C_pack_empty = tensor.empty () : tensor <2 x 2 x 8 x 8 x i32 >
86+ %C_pack_empty = tensor.empty () : tensor <1 x 2 x 8 x 8 x i32 >
8787
8888 // Pack matrices
89- %A_pack = linalg.pack %A padding_value (%zero : i32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %A_pack_empty : tensor <7 x16 xi32 > -> tensor <2 x 16 x 8 x 1 x i32 >
89+ %A_pack = linalg.pack %A padding_value (%zero : i32 ) inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %A_pack_empty : tensor <7 x16 xi32 > -> tensor <1 x 16 x 8 x 1 x i32 >
9090 %B_pack = linalg.pack %B padding_value (%zero : i32 ) outer_dims_perm = [1 , 0 ] inner_dims_pos = [1 , 0 ] inner_tiles = [8 , 1 ] into %B_pack_empty : tensor <16 x13 xi32 > -> tensor <2 x16 x8 x1 xi32 >
91- %C_pack = linalg.pack %C padding_value (%zero : i32 ) outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %C_pack_empty : tensor <7 x13 xi32 > -> tensor <2 x 2 x 8 x 8 x i32 >
91+ %C_pack = linalg.pack %C padding_value (%zero : i32 ) outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %C_pack_empty : tensor <7 x13 xi32 > -> tensor <1 x 2 x 8 x 8 x i32 >
9292
9393 // MMT4D
94- %mmt4d = linalg.mmt4d ins (%A_pack , %B_pack : tensor <2 x 16 x 8 x 1 x i32 >, tensor <2 x16 x8 x1 xi32 >) outs (%C_pack : tensor <2 x 2 x 8 x 8 x i32 >) -> tensor <2 x 2 x 8 x 8 x i32 >
94+ %mmt4d = linalg.mmt4d ins (%A_pack , %B_pack : tensor <1 x 16 x 8 x 1 x i32 >, tensor <2 x16 x8 x1 xi32 >) outs (%C_pack : tensor <1 x 2 x 8 x 8 x i32 >) -> tensor <1 x 2 x 8 x 8 x i32 >
9595
9696 // Unpack output
9797 %C_out_empty = tensor.empty () : tensor <7 x13 xi32 >
98- %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %C_out_empty : tensor <2 x 2 x 8 x 8 x i32 > -> tensor <7 x13 xi32 >
98+ %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0 , 1 ] inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %C_out_empty : tensor <1 x 2 x 8 x 8 x i32 > -> tensor <7 x13 xi32 >
9999
100100 return %C_out_unpack : tensor <7 x13 xi32 >
101101}
0 commit comments