 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.nn.functional as F
 
 Tensor = torch.Tensor
 
@@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor:
     n_row_blocks = ceil_div(rows, 128)
     n_col_blocks = ceil_div(cols, 4)
 
-    # Pad out and view as tiles of (128, 4)
-    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols),
+            device=input_matrix.device,
+            dtype=input_matrix.dtype,
+        )
+        padded[:rows, :cols] = input_matrix
 
-    # rearrange all tiles
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
     rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
 
-    # Layout rearranged tiles according to second pic
     return rearranged.flatten()
 
 
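For reference, a minimal usage sketch of `to_blocked` after this change. The input shape, dtype, and values below are placeholder assumptions for illustration, not part of the PR; the function only needs a 2D tensor and is assumed to be in scope from this module.

```python
import torch

# Hypothetical example input (not from the PR): a 100x10 scale matrix.
scales = torch.arange(100 * 10, dtype=torch.float32).reshape(100, 10)

# Rows pad up to the next multiple of 128 (-> 128) and cols to the next
# multiple of 4 (-> 12); the (128, 4) tiles are then rearranged and flattened.
blocked = to_blocked(scales)

assert blocked.numel() == 128 * 12  # 1536 elements in the blocked layout
```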