diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 2d9cf1d48fd5..93453ddb657c 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+import torch
 
 from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
@@ -16,6 +17,8 @@ from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
@@ -98,6 +101,34 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
+@pytest.mark.skipif(
+    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
+    reason="AITER is only supported on ROCm and requires FP8_FNUZ")
+@pytest.mark.parametrize("use_cutlass", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
+def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
+                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
+                                  monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
+                       use_rocm_aiter_gemm_w8a8_blockscale)
+
+    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
+        int(use_rocm_aiter_gemm_w8a8_blockscale)))
+    block_scale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
+    if use_cutlass:
+        assert block_scale_func == cutlass_scaled_mm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_gemm_w8a8_blockscale):
+        assert block_scale_func == (
+            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
+    else:
+        assert block_scale_func == w8a8_block_fp8_matmul
+
+
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index ca3126354a1a..5b5f25909c33 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -182,6 +182,13 @@ def __init__(self, quant_config: Fp8Config):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        # AITER is only supported on ROCm and requires FP8_FNUZ, which at
+        # the moment means the MI300 series.
+        self.use_aiter_and_is_supported = (current_platform.is_rocm()
+                                           and envs.VLLM_ROCM_USE_AITER
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
+                                           and current_platform.is_fp8_fnuz())
+
         self.block_quant = self.quant_config.weight_block_size is not None
         self.fp8_linear = Fp8LinearOp(
             # Default to using per_token quantization if cutlass is supported
@@ -402,6 +409,7 @@ def apply(self,
                 input_scale=layer.input_scale,
                 bias=bias,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
 
         return self.fp8_linear.apply(input=x,
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 723d2ffd4318..8f525ef1452a 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -4,7 +4,7 @@
 import functools
 import json
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -27,6 +27,76 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
+def cutlass_scaled_mm(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    return ops.cutlass_scaled_mm(A,
+                                 B.T,
+                                 out_dtype=output_dtype,
+                                 scale_a=As,
+                                 scale_b=Bs.T)
+
+
+def rocm_aiter_gemm_w8a8_blockscale_impl(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    import aiter as rocm_aiter
+
+    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_blockscale_fake(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        block_size: list[int],
+        output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8_blockscale",
+        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if use_aiter_and_is_supported:
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -37,26 +107,23 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
-                                  and weight.shape[1] % 128 == 0)
-    if current_platform.is_rocm():
-        # TODO this is never used, as cutlass_block_fp8_supported is False
-        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
-                         input_2d.shape[:-1])[::-1]
-        scale_b_shape = (weight_scale.view(-1, 1)
-                         if weight_scale.dim() <= 1 else weight_scale.T).shape
-        ar, ac = scale_a_shape
-        br, bc = scale_b_shape
-        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
-                or br not in (1, weight.shape[0])):
-            shape_supported_by_cutlass = False
-    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+    if current_platform.is_cuda():
+        use_cutlass = cutlass_block_fp8_supported and (
+            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+    else:
+        use_cutlass = False
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported)
+
+    if use_cutlass:
         rows, cols = input_2d.shape
         # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
         # optimal tensor core usage. Can be removed when targeting platforms
@@ -67,26 +134,22 @@ def apply_w8a8_block_fp8_linear(
             input_2d = torch.nn.functional.pad(input_2d,
                                                (0, 0, 0, 4 - (rows % 4)),
                                                value=0).contiguous()
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=True)
-        output = ops.cutlass_scaled_mm(q_input,
-                                       weight.T,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale.T)
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
         if should_pad:
             output = output[:rows, :]
+
     else:
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=False)
-        output = w8a8_block_fp8_matmul(q_input,
-                                       weight,
-                                       x_scale,
-                                       weight_scale,
-                                       block_size,
-                                       output_dtype=input.dtype)
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
@@ -98,6 +161,9 @@ def apply_w8a8_block_fp8_linear_fake(
     block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
     return torch.empty(output_shape, dtype=input.dtype, device=input.device)
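
Reviewer note: below is a minimal sketch (not part of the patch) of how the new dispatcher in fp8_utils.py resolves the w8a8 block-scale GEMM backend. It assumes a vLLM checkout with this diff applied; the flag combinations mirror the new test rather than anything beyond the patch.

# Sketch only: exercises dispatch_w8a8_blockscale_func introduced above.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)

# CUTLASS takes precedence whenever the caller selects it.
assert dispatch_w8a8_blockscale_func(
    use_cutlass=True, use_aiter_and_is_supported=True) == cutlass_scaled_mm

# With neither CUTLASS nor AITER available, the Triton kernel is the fallback.
assert dispatch_w8a8_blockscale_func(
    use_cutlass=False,
    use_aiter_and_is_supported=False) == w8a8_block_fp8_matmul

# On a ROCm FP8_FNUZ build with VLLM_ROCM_USE_AITER=1 and
# VLLM_ROCM_USE_AITER_LINEAR=1, use_cutlass=False together with
# use_aiter_and_is_supported=True returns the registered custom op
# torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale instead.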