@@ -618,7 +618,9 @@ def from_plain(
 
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
+        # NOTE(reviewers): Please check if this is what I should do.
         q_w_24 = int_data.t()
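+        # Scale arrives flattened; reshape it to (num_groups, out_features) so its
+        # last dim lines up with the transposed weight above.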
+        scale = scale.reshape(-1, q_w_24.shape[1])
 
         if q_w_24.dtype != torch.int32:
             raise ValueError("Only `torch.int32` weights are supported.")
@@ -631,15 +633,14 @@ def from_plain(
 
         # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
         # will require a bit more work to get our current quantization flow to work with it.
-        # Check the below link for a reference:
-        # https://github.com/neuralmagic/nm-vllm/tree/main
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
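+        # Heuristic: if every quantized value is below 2**4, the weight fits in 4 bits.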
         num_bits = 4 if torch.max(q_w_24) < 16 else -1
         if num_bits not in [4]:
             raise ValueError(
-                f"Only {const.SUPPORTED_NUM_BITS} bits are supported, got {num_bits}."
+                f"Only {[4]} bits are supported, got {num_bits}."
             )
 
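+        # scale is (num_groups, out_features) after the reshape above, so e.g.
+        # in_features=128 with 4 groups gives group_size=32.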
-        group_size = in_features // scale.shape[-1]
+        group_size = in_features // scale.shape[0]
         if group_size == 0:
             group_size = in_features
         assert group_size <= in_features, "Group size must be less than or equal to in_features."
@@ -1043,27 +1044,44 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
         isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
     )
 
-
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
-    from torchao.sparsity.marlin import marlin_24_workspace
+    from torchao.sparsity.marlin import marlin_24_workspace, const
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
     meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
     num_bits = weight_tensor.layout_tensor.num_bits
 
+    # Saves the batch size for reshaping back to the original shape after the matmul
+    # and flattens the input to (m, k), where m is batch * seq_len and k is in_features.
+    # NOTE(reviewers): Please check if I am handling the batch size correctly
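+    # e.g. (batch, seq_len, in_features) -> (batch * seq_len, in_features)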
+    batch_size = -1
+    if input_tensor.dim() == 3:
+        batch_size = input_tensor.size(0)
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+
     size_m = input_tensor.shape[0]
-    size_n = original_shape[0]
+    size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])
 
+    # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
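+    # find_multiple rounds size_k up to the smallest multiple of const.TILE >= size_k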
+    if size_k % const.TILE != 0:
+        pad_size = find_multiple(size_k, const.TILE)
+        input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
+        size_k = pad_size
+
     out = torchao.ops.marlin_24_gemm(
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
     torch.cuda.synchronize()
 
+    # Reshape back to original shape
+    if batch_size != -1:
+        out = out.reshape(batch_size, -1, out.shape[-1])
+
     if bias is not None:
         out += bias.to(out.dtype)
     return out
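
For context, a minimal sketch of how this dispatch path could be exercised end to end. The `quantize_`, `int4_weight_only`, and `layout_type` entry points are assumptions based on torchao's quantization API around this change and are not part of this diff; only `MarlinSparseLayoutType` appears above.

```python
# Sketch only, assuming torchao's quantize_/int4_weight_only API and a CUDA device.
# In practice the weight should also satisfy the 2:4 sparsity pattern the marlin
# kernel expects (e.g. via torchao's sparsity tooling) before using this layout.
import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).half().cuda()
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

# A dim == 3 input takes the batched path above: flattened to (batch * seq_len, k)
# for the kernel, then reshaped back to (batch, seq_len, n) afterwards.
x = torch.randn(2, 16, 1024, dtype=torch.float16, device="cuda")
y = model(x)
print(y.shape)  # torch.Size([2, 16, 4096])
```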