From 1032cf38894c32401556a602faf6f69995cba7fe Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 29 Oct 2024 15:11:03 -0400
Subject: [PATCH 1/2] Make apply_fp8_linear work with >2D input

---
 .../layers/quantization/utils/w8a8_utils.py | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 1879d2855d93..a041b4716bec 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -96,21 +96,26 @@ def apply_fp8_linear(
     # If dynamic, layer.input_scale is None and x_scale computed from x.
     # If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # View input as 2D matrix for fp8 methods
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[1]]
+
     # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             scale_ub=input_scale_ub,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        return ops.cutlass_scaled_mm(qinput,
+        output = ops.cutlass_scaled_mm(qinput,
                                      weight,
                                      out_dtype=input.dtype,
                                      scale_a=x_scale,
                                      scale_b=weight_scale,
                                      bias=bias)
+        return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -119,7 +124,7 @@ def apply_fp8_linear(
         # for matrices with batch dimension > 16.
         # This could change in the future.
         qinput, x_scale = ops.scaled_fp8_quant(
-            input,
+            input_2d,
             input_scale,
             num_token_padding=17,
             use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +143,9 @@ def apply_fp8_linear(
             # A fix for discrepancy in scaled_mm which returns tuple
             # for torch < 2.5 and a single value in torch >= 2.5
             if type(output) is tuple and len(output) == 2:
-                return torch.narrow(output[0], 0, 0, input.shape[0])
-            return torch.narrow(output, 0, 0, input.shape[0])
+                output = output[0]
+
+            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
@@ -176,15 +182,15 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
             # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+            output = torch.narrow(output, 0, 0, input_2d.shape[0])
+            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
 
             # DQ
             # C = sw * sx * (X * W) + bias
             output = output * x_scale * weight_scale.t()
             if bias is not None:
                 output = output + bias
-            return output.to(dtype=input.dtype)
+            return output.to(dtype=input.dtype).view(*output_shape)
 
 
 def apply_int8_linear(

From ebd38af92c3d2f2ec0d9740f8a861458a0ecf68b Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 29 Oct 2024 16:49:20 -0400
Subject: [PATCH 2/2] Update w8a8_utils.py

---
 .../layers/quantization/utils/w8a8_utils.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index a041b4716bec..445117ac99a3 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -110,11 +110,11 @@ def apply_fp8_linear(
 
         # Fused GEMM_DQ
         output = ops.cutlass_scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+                                       weight,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale,
+                                       bias=bias)
         return output.view(*output_shape)
 
     # torch.scaled_mm supports per tensor weights + activations only
@@ -145,7 +145,8 @@ def apply_fp8_linear(
             if type(output) is tuple and len(output) == 2:
                 output = output[0]
 
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+            return torch.narrow(output, 0, 0,
+                                input_2d.shape[0]).view(*output_shape)
 
         else:
             # Fallback for channelwise case, where we use unfused DQ
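
Note: the core of both patches is a flatten-then-restore pattern: collapse all
leading dimensions of the activation into a single token dimension before the
fp8 quantize/GEMM (which expects a 2D matrix), then view the result back to the
original leading dimensions. Below is a minimal sketch of that pattern, with a
plain matmul standing in for the ops.scaled_fp8_quant / ops.cutlass_scaled_mm
path; fake_fp8_gemm and linear_nd are illustrative names, not vLLM APIs.

import torch


def fake_fp8_gemm(a_2d: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real quantize + scaled-matmul path, which only
    # accepts a 2D activation matrix of shape [num_tokens, hidden_in].
    assert a_2d.dim() == 2
    return a_2d @ b


def linear_nd(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # View input as a 2D matrix, as the patch does before calling the fp8 ops.
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]

    output = fake_fp8_gemm(input_2d, weight)

    # Restore the original leading dimensions on the way out.
    return output.view(*output_shape)


x = torch.randn(2, 3, 5, 8)   # e.g. [batch, seq, experts, hidden_in]
w = torch.randn(8, 16)        # [hidden_in, hidden_out]
assert linear_nd(x, w).shape == (2, 3, 5, 16)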