From 961de9ee5a551c2f594e00b649088c18e4f56dcc Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 13:03:57 +0000 Subject: [PATCH 01/18] add aiter block-scaled gemm Signed-off-by: vllmellm --- .../model_executor/test_enabled_custom_ops.py | 25 +++++ vllm/envs.py | 16 +++ .../layers/quantization/utils/fp8_utils.py | 99 ++++++++++++------- vllm/platforms/interface.py | 3 + vllm/platforms/rocm.py | 3 + 5 files changed, 113 insertions(+), 33 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 4a6a766b8ca0..fe62a3165978 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -8,6 +8,10 @@ ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + cutlass_scaled_mm, dispatch_w8a8_blockscale_func, + rocm_aiter_gemm_w8a8_blockscale, w8a8_block_fp8_matmul) +from vllm.platforms import current_platform # Registered subclass for test @@ -87,3 +91,24 @@ def test_enabled_ops_invalid(env: str): custom_ops=env.split(","))) with set_current_vllm_config(vllm_config): RMSNorm(1024).enabled() + + +@pytest.mark.parametrize("use_cutlass", [True, False]) +@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) +def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, + use_rocm_aiter_gemm_w8a8_blockscale: str, + monkeypatch): + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + monkeypatch.setenv("VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE", + use_rocm_aiter_gemm_w8a8_blockscale) + block_scale_func = dispatch_w8a8_blockscale_func(use_cutlass) + + if use_cutlass: + assert block_scale_func == cutlass_scaled_mm + elif current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_gemm_w8a8_blockscale): + assert block_scale_func == rocm_aiter_gemm_w8a8_blockscale + else: + assert block_scale_func == w8a8_block_fp8_matmul diff --git a/vllm/envs.py b/vllm/envs.py index bf214f314c45..0bd3b7949e92 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,6 +75,8 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE: bool = False VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -522,6 +524,20 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), + # Use aiter ops unless explicitly disabled. + # Acts as a parent switch to enable the rest of the operations. + "VLLM_ROCM_USE_AITER": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1")), + + # Use aiter w8a8 block gemm kernel if aiter ops are enabled. + # This is disabled by default. 
+ "VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE": + lambda: + (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE", + "False").lower() in ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ecb7996e1e8c..6a741fd46b7e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -4,7 +4,7 @@ import functools import json import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import triton @@ -28,6 +28,55 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz +def is_shape_supported_by_cutlass(weight: torch.Tensor, block_size: List[int], + weight_scale: torch.Tensor, + input_2d: torch.Tensor) -> bool: + if current_platform.is_rocm(): + scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + + input_2d.shape[:-1])[::-1] + scale_b_shape = (weight_scale.view(-1, 1) + if weight_scale.dim() <= 1 else weight_scale.T).shape + ar, ac = scale_a_shape + br, bc = scale_b_shape + return ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) \ + or br not in (1, weight.shape[0]) + + return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0 + + +def cutlass_scaled_mm(A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, + Bs: torch.Tensor, output_dtype: torch.dtype, + **kwargs) -> torch.Tensor: + return ops.cutlass_scaled_mm(A, + B.T, + out_dtype=output_dtype, + scale_a=As, + scale_b=Bs.T) + + +def rocm_aiter_gemm_w8a8_blockscale(A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, + **kwargs) -> torch.Tensor: + import aiter as rocm_aiter + + output = torch.zeros([A.shape[0], B.shape[0]], + dtype=output_dtype, + device=A.device) + return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, output) + + +def dispatch_w8a8_blockscale_func( + use_cutlass: bool) -> Callable[..., torch.Tensor]: + if use_cutlass: + return cutlass_scaled_mm + if current_platform.is_rocm_aiter_gemm_w8a8_blockscale_enabled(): + return rocm_aiter_gemm_w8a8_blockscale + return w8a8_block_fp8_matmul + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -44,38 +93,22 @@ def apply_w8a8_block_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 - and weight.shape[1] % 128 == 0) - if current_platform.is_rocm(): - # TODO this is never used, as cutlass_block_fp8_supported is False - scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + - input_2d.shape[:-1])[::-1] - scale_b_shape = (weight_scale.view(-1, 1) - if weight_scale.dim() <= 1 else weight_scale.T).shape - ar, ac = scale_a_shape - br, bc = scale_b_shape - if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) - or br not in (1, weight.shape[0])): - shape_supported_by_cutlass = False - if cutlass_block_fp8_supported and shape_supported_by_cutlass: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=True) - output = 
ops.cutlass_scaled_mm(q_input, - weight.T, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale.T) - else: - q_input, x_scale = per_token_group_quant_fp8(input_2d, - block_size[1], - column_major_scales=False) - output = w8a8_block_fp8_matmul(q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=input.dtype) + # TODO is_shape_supported_by_cutlass is never used, + # as cutlass_block_fp8_supported is False + use_cutlass = cutlass_block_fp8_supported and is_shape_supported_by_cutlass( + weight, block_size, weight_scale, input_2d) + + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + + output = dispatch_w8a8_blockscale_func(use_cutlass)( + A=q_input, + B=weight, + As=x_scale, + Bs=weight_scale, + block_size=block_size, + output_dtype=input.dtype) + if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 7415b5d5f060..3c3371888b22 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -143,6 +143,9 @@ def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + def is_rocm_aiter_w8a8_block_gemm_enabled(self) -> bool: + return False + @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 75f287b568ac..877c78b23f43 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -110,6 +110,9 @@ class RocmPlatform(Platform): "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" ] + def is_rocm_aiter_gemm_w8a8_blockscale_enabled(self) -> bool: + return envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE + @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, From 45efca97433ebd0690cc3e49f494529d7d35ab55 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 12:24:33 +0000 Subject: [PATCH 02/18] include AITER enable for rocm platforms in model end to end tests Signed-off-by: vllmellm --- .buildkite/run-amd-test.sh | 4 + .../decoder_only/language/test_mistral.py | 80 +++++++++++-------- .../decoder_only/language/test_models.py | 10 +++ .../decoder_only/language/test_phimoe.py | 19 ++--- tests/quantization/test_fp8.py | 22 ++++- 5 files changed, 91 insertions(+), 44 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 0680bae13ddb..2e15533ffcf8 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -72,6 +72,10 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" +# environment variables +SKIP_ROCM_ATIER_MODEL_TEST_CASES="True" +echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES + commands=$@ echo "Commands:$commands" #ignore certain kernels tests diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 4c2055361d44..2809b0c98012 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -12,6 +12,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) +from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ 
-174,15 +175,16 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -206,14 +208,16 @@ def test_models( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner( model, dtype=dtype, @@ -244,11 +248,15 @@ def test_mistral_format( @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages( - vllm_runner, - model: str, - dtype: str, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -266,11 +274,15 @@ def test_mistral_symbolic_languages( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -def test_mistral_function_calling( - vllm_runner, - model: str, - dtype: str, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_function_calling(vllm_runner, model: str, dtype: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -301,11 +313,15 @@ def test_mistral_function_calling( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -def test_mistral_guided_decoding( - vllm_runner, - model: str, - guided_backend: str, -) -> None: +@pytest.mark.parametrize( + 
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index a49926ea220e..7a25d652195d 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -5,6 +5,8 @@ """ import pytest +from vllm.platforms import current_platform + from ...utils import check_logprobs_close # These have unsupported head_dim for FA. We do not @@ -69,6 +71,8 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models( hf_runner, vllm_runner, @@ -77,11 +81,17 @@ def test_models( dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch, ) -> None: if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with hf_runner(model, dtype=dtype) as hf_model: if model.startswith("THUDM/chatglm3"): hf_model.model.get_output_embeddings = lambda: \ diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index f9757d6ac295..2badcaf104bd 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -79,15 +79,16 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + pytest.skip("Skipping test suite for ROCM AITER") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 19cf29d3e659..5cadc8d5dd49 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -23,11 +23,16 @@ reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - monkeypatch) -> None: + use_rocm_aiter: 
bool, monkeypatch) -> None: if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy @@ -47,7 +52,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, + use_rocm_aiter: bool, monkeypatch): + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: @@ -86,8 +97,13 @@ def check_model(model): reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - monkeypatch) -> None: + use_rocm_aiter: bool, monkeypatch) -> None: + if use_rocm_aiter: + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") From 68a0479e8c273a03d52fd4fd58cd16f5ac966aff Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 17 Mar 2025 15:09:21 +0000 Subject: [PATCH 03/18] modify rocm docker Signed-off-by: vllmellm --- Dockerfile.rocm_base | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index e33e73b30309..962555ceae4f 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -6,12 +6,14 @@ ARG RCCL_BRANCH="648a58d" ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="3a585126" -ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_BRANCH="6c0e7463" +ARG PYTORCH_VISION_BRANCH="v0.21.0" ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" -ARG FA_BRANCH="b7d29fb" -ARG FA_REPO="https://github.com/ROCm/flash-attention.git" +ARG FA_BRANCH="1a7f4dfa" +ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" +ARG AITER_BRANCH="e1ec015" +ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -108,7 +110,7 @@ RUN git clone ${FA_REPO} RUN cd flash-attention \ && git checkout ${FA_BRANCH} \ && git submodule update --init \ - && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist + && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ && cp /app/vision/dist/*.whl /app/install \ && cp /app/flash-attention/dist/*.whl /app/install @@ -129,7 +131,17 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ RUN 
--mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ pip install /install/*.whl +ARG AITER_REPO +ARG AITER_BRANCH +RUN git clone --recursive ${AITER_REPO} +RUN cd aiter \ + && git checkout ${AITER_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter + ARG BASE_IMAGE +ARG HIPBLAS_COMMON_BRANCH ARG HIPBLASLT_BRANCH ARG LEGACY_HIPBLASLT_OPTION ARG RCCL_BRANCH @@ -155,4 +167,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ - && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt + && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file From 952e55ec97fe6a7358e3f1648aef5f593350a92b Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Mar 2025 05:21:02 +0000 Subject: [PATCH 04/18] remove aiter blockscale flag check from current platform Signed-off-by: vllmellm --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 5 ++++- vllm/platforms/interface.py | 3 --- vllm/platforms/rocm.py | 3 --- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 6a741fd46b7e..4591951be20a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -10,6 +10,7 @@ import triton import triton.language as tl +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -70,9 +71,11 @@ def rocm_aiter_gemm_w8a8_blockscale(A: torch.Tensor, def dispatch_w8a8_blockscale_func( use_cutlass: bool) -> Callable[..., torch.Tensor]: + use_aiter_gemm_w8a8_blockscale = (current_platform.is_rocm() and \ + envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE) if use_cutlass: return cutlass_scaled_mm - if current_platform.is_rocm_aiter_gemm_w8a8_blockscale_enabled(): + if use_aiter_gemm_w8a8_blockscale: return rocm_aiter_gemm_w8a8_blockscale return w8a8_block_fp8_matmul diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3c3371888b22..7415b5d5f060 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -143,9 +143,6 @@ def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) - def is_rocm_aiter_w8a8_block_gemm_enabled(self) -> bool: - return False - @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 877c78b23f43..75f287b568ac 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -110,9 +110,6 @@ class RocmPlatform(Platform): "fbgemm_fp8", "gguf", "quark", "ptpc_fp8" ] - def is_rocm_aiter_gemm_w8a8_blockscale_enabled(self) -> bool: - return envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE - @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, From 70ae1310d4e66ebbf90926648ecad9792ba0a255 Mon Sep 17 00:00:00 2001 
From: vllmellm Date: Tue, 18 Mar 2025 06:48:57 +0000 Subject: [PATCH 05/18] keep packages versions in rocm docker_base file same as main and only add AITER package Signed-off-by: vllmellm --- Dockerfile.rocm_base | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index 962555ceae4f..50d23cfc9ad5 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -6,12 +6,12 @@ ARG RCCL_BRANCH="648a58d" ARG RCCL_REPO="https://github.com/ROCm/rccl" ARG TRITON_BRANCH="e5be006" ARG TRITON_REPO="https://github.com/triton-lang/triton.git" -ARG PYTORCH_BRANCH="6c0e7463" -ARG PYTORCH_VISION_BRANCH="v0.21.0" +ARG PYTORCH_BRANCH="3a585126" +ARG PYTORCH_VISION_BRANCH="v0.19.1" ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" -ARG FA_BRANCH="1a7f4dfa" -ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" +ARG FA_BRANCH="b7d29fb" +ARG FA_REPO="https://github.com/ROCm/flash-attention.git" ARG AITER_BRANCH="e1ec015" ARG AITER_REPO="https://github.com/ROCm/aiter.git" @@ -110,7 +110,7 @@ RUN git clone ${FA_REPO} RUN cd flash-attention \ && git checkout ${FA_BRANCH} \ && git submodule update --init \ - && GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist + && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ && cp /app/vision/dist/*.whl /app/install \ && cp /app/flash-attention/dist/*.whl /app/install @@ -139,9 +139,8 @@ RUN cd aiter \ && git submodule update --init --recursive \ && pip install -r requirements.txt \ && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter - + ARG BASE_IMAGE -ARG HIPBLAS_COMMON_BRANCH ARG HIPBLASLT_BRANCH ARG LEGACY_HIPBLASLT_OPTION ARG RCCL_BRANCH @@ -167,5 +166,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt \ No newline at end of file From a5c2e895e675a9233bf0b19ba3b916fbf1c9ed19 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 18 Mar 2025 07:42:32 +0000 Subject: [PATCH 06/18] fix get envs variables in unit tests Signed-off-by: vllmellm --- tests/models/decoder_only/language/test_mistral.py | 11 ++++++----- tests/models/decoder_only/language/test_models.py | 4 +++- tests/models/decoder_only/language/test_phimoe.py | 4 +++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 2809b0c98012..8c6353d5b3eb 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -5,6 +5,7 @@ """ import copy import json +import os import jsonschema import jsonschema.exceptions @@ -181,7 +182,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if 
os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -214,7 +215,7 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -253,7 +254,7 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -279,7 +280,7 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, def test_mistral_function_calling(vllm_runner, model: str, dtype: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -318,7 +319,7 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str, def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 7a25d652195d..593fc7af2fb4 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,6 +3,8 @@ Run `pytest tests/models/test_models.py`. """ +import os + import pytest from vllm.platforms import current_platform @@ -88,7 +90,7 @@ def test_models( monkeypatch.setenv("VLLM_USE_V1", "0") if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 2badcaf104bd..d9cfac1d3b38 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -3,6 +3,8 @@ Run `pytest tests/models/test_phimoe.py`. 
""" +import os + import pytest import torch @@ -85,7 +87,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, monkeypatch) -> None: if use_rocm_aiter: - if monkeypatch.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": + if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": pytest.skip("Skipping test suite for ROCM AITER") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") From 9507aac7b1a6cdc735f0fa5e4947a6c3c4de38eb Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 24 Mar 2025 15:55:46 +0000 Subject: [PATCH 07/18] remove cascading logic in envs and use CK gemm w8a8 block from AITER Signed-off-by: vllmellm --- vllm/envs.py | 5 ++--- .../layers/quantization/utils/fp8_utils.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 0bd3b7949e92..84935862a358 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -534,9 +534,8 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # This is disabled by default. "VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE": lambda: - (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in - ("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE", - "False").lower() in ("true", "1")), + (os.getenv("VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE", "False").lower() in + ("true", "1")), # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 4591951be20a..0fa69aebe69e 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -63,19 +63,20 @@ def rocm_aiter_gemm_w8a8_blockscale(A: torch.Tensor, **kwargs) -> torch.Tensor: import aiter as rocm_aiter - output = torch.zeros([A.shape[0], B.shape[0]], - dtype=output_dtype, - device=A.device) - return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, output) + return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) + + +def is_rocm_aiter_gemm_w8a8_blockscale_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE def dispatch_w8a8_blockscale_func( use_cutlass: bool) -> Callable[..., torch.Tensor]: - use_aiter_gemm_w8a8_blockscale = (current_platform.is_rocm() and \ - envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE) if use_cutlass: return cutlass_scaled_mm - if use_aiter_gemm_w8a8_blockscale: + if is_rocm_aiter_gemm_w8a8_blockscale_enabled(): return rocm_aiter_gemm_w8a8_blockscale return w8a8_block_fp8_matmul From 7f17da768e4696dac49d87414e099f9a338e2b58 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 24 Mar 2025 15:57:09 +0000 Subject: [PATCH 08/18] revert back the unittest to their original format Signed-off-by: vllmellm --- .buildkite/run-amd-test.sh | 4 -- .../decoder_only/language/test_mistral.py | 54 +++---------------- .../decoder_only/language/test_models.py | 12 ----- .../decoder_only/language/test_phimoe.py | 11 +--- tests/quantization/test_fp8.py | 22 ++------ 5 files changed, 11 insertions(+), 92 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 2e15533ffcf8..0680bae13ddb 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -72,10 +72,6 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" -# environment 
variables -SKIP_ROCM_ATIER_MODEL_TEST_CASES="True" -echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES - commands=$@ echo "Commands:$commands" #ignore certain kernels tests diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 8c6353d5b3eb..ec885386dd94 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -5,7 +5,6 @@ """ import copy import json -import os import jsonschema import jsonschema.exceptions @@ -13,7 +12,6 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa MistralToolParser) -from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...utils import check_logprobs_close @@ -176,16 +174,8 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + dtype: str, max_tokens: int, num_logprobs: int) -> None: # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -209,16 +199,8 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + max_tokens: int, num_logprobs: int) -> None: with vllm_runner( model, dtype=dtype, @@ -249,15 +231,8 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +def test_mistral_symbolic_languages(vllm_runner, model: str, + dtype: str) -> None: with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -275,15 +250,7 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_mistral_function_calling(vllm_runner, model: 
str, dtype: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -314,15 +281,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str, @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +def test_mistral_guided_decoding(vllm_runner, model: str, + guided_backend: str) -> None: with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 593fc7af2fb4..a49926ea220e 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,12 +3,8 @@ Run `pytest tests/models/test_models.py`. """ -import os - import pytest -from vllm.platforms import current_platform - from ...utils import check_logprobs_close # These have unsupported head_dim for FA. We do not @@ -73,8 +69,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models( hf_runner, vllm_runner, @@ -83,17 +77,11 @@ def test_models( dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch, ) -> None: if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - with hf_runner(model, dtype=dtype) as hf_model: if model.startswith("THUDM/chatglm3"): hf_model.model.get_output_embeddings = lambda: \ diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index d9cfac1d3b38..7194efaa7bca 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -3,7 +3,6 @@ Run `pytest tests/models/test_phimoe.py`. 
""" -import os import pytest import torch @@ -81,16 +80,8 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true": - pytest.skip("Skipping test suite for ROCM AITER") - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + dtype: str, max_tokens: int, num_logprobs: int) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 5cadc8d5dd49..19cf29d3e659 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -23,16 +23,11 @@ reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: + monkeypatch) -> None: if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy @@ -52,13 +47,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, - use_rocm_aiter: bool, monkeypatch): - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): # vllm_runner.apply_model() relies on V0 internals. monkeypatch.setenv("VLLM_USE_V1", "0") with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: @@ -97,13 +86,8 @@ def check_model(model): reason="FP8 is not supported on this GPU type.") @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) -@pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - + monkeypatch) -> None: # vllm_runner.apply_model() relies on V0 internals. 
monkeypatch.setenv("VLLM_USE_V1", "0") From 31fefa3eee7d5075f82a776bc5a6b71bc5fb6ca9 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 24 Mar 2025 16:05:26 +0000 Subject: [PATCH 09/18] revert back unneccessary changes Signed-off-by: vllmellm --- tests/models/decoder_only/language/test_mistral.py | 11 +++++++++-- tests/models/decoder_only/language/test_phimoe.py | 1 - 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index ec885386dd94..717dca0b2244 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -174,8 +174,15 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 7194efaa7bca..5e43f20bd2b1 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -3,7 +3,6 @@ Run `pytest tests/models/test_phimoe.py`. """ - import pytest import torch From 6bed514473178526dcaf24cd28b03c03ab57f29e Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 24 Mar 2025 16:09:19 +0000 Subject: [PATCH 10/18] revert back unneccassary changes Signed-off-by: vllmellm --- .../decoder_only/language/test_mistral.py | 30 ++++++++++++++----- .../decoder_only/language/test_phimoe.py | 11 +++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 717dca0b2244..4c2055361d44 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -206,8 +206,14 @@ def test_models( @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with vllm_runner( model, dtype=dtype, @@ -238,8 +244,11 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages(vllm_runner, model: str, - dtype: str) -> None: +def test_mistral_symbolic_languages( + vllm_runner, + model: str, + dtype: str, +) -> None: with vllm_runner(model, dtype=dtype, max_model_len=8192, @@ -257,7 +266,11 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) # v1 can't do func calling -def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: +def test_mistral_function_calling( + 
vllm_runner, + model: str, + dtype: str, +) -> None: with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral", @@ -288,8 +301,11 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("guided_backend", ["outlines", "lm-format-enforcer", "xgrammar"]) -def test_mistral_guided_decoding(vllm_runner, model: str, - guided_backend: str) -> None: +def test_mistral_guided_decoding( + vllm_runner, + model: str, + guided_backend: str, +) -> None: with vllm_runner(model, dtype='bfloat16', tokenizer_mode="mistral") as vllm_model: diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 5e43f20bd2b1..f9757d6ac295 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -79,8 +79,15 @@ def test_phimoe_routing_function(): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) From d6cce8f50612376d24839f977285cedb02c75ce1 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 16 Apr 2025 18:16:54 +0000 Subject: [PATCH 11/18] wrap aiter kernel into direct register custom op Signed-off-by: tjtanaa --- .../model_executor/test_enabled_custom_ops.py | 7 ++-- .../layers/quantization/utils/fp8_utils.py | 41 +++++++++++++++---- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 08431e8ee893..bd4068b443dd 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp @@ -15,8 +16,7 @@ RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, dispatch_w8a8_blockscale_func, - rocm_aiter_gemm_w8a8_blockscale, w8a8_block_fp8_matmul) + cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -115,7 +115,8 @@ def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, assert block_scale_func == cutlass_scaled_mm elif current_platform.is_rocm() and int(use_rocm_aiter) and int( use_rocm_aiter_gemm_w8a8_blockscale): - assert block_scale_func == rocm_aiter_gemm_w8a8_blockscale + assert block_scale_func == ( + torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) else: assert block_scale_func == w8a8_block_fp8_matmul diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 0fa69aebe69e..5220c97b363b 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ 
-55,17 +55,44 @@ def cutlass_scaled_mm(A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, scale_b=Bs.T) -def rocm_aiter_gemm_w8a8_blockscale(A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float16, - **kwargs) -> torch.Tensor: +def rocm_aiter_gemm_w8a8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, + block_size: Optional[List[int]] = None, +) -> torch.Tensor: import aiter as rocm_aiter return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) +def rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float16, + block_size: Optional[List[int]] = None, +) -> torch.Tensor: + + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, + dispatch_key=current_platform.dispatch_key, + ) + + def is_rocm_aiter_gemm_w8a8_blockscale_enabled() -> bool: return current_platform.is_rocm() \ and envs.VLLM_ROCM_USE_AITER \ @@ -77,7 +104,7 @@ def dispatch_w8a8_blockscale_func( if use_cutlass: return cutlass_scaled_mm if is_rocm_aiter_gemm_w8a8_blockscale_enabled(): - return rocm_aiter_gemm_w8a8_blockscale + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return w8a8_block_fp8_matmul From 55260a2e46ac571a71e9aa6b4e1d09c5a2875251 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 22 Apr 2025 10:03:01 +0000 Subject: [PATCH 12/18] refactor the conditional checking logic to reduce checking overhead Signed-off-by: tjtanaa --- .../model_executor/layers/quantization/fp8.py | 13 +++- .../layers/quantization/utils/fp8_utils.py | 60 +++++++++---------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index be76785baccc..bea51e5c31d1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -19,6 +19,8 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_w8a8_block_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -174,6 +176,13 @@ def __init__(self, quant_config: Fp8Config): if current_platform.is_rocm(): self.use_marlin = False + # AITER is only supported on ROCm and only for FP8_FNUZ + # and at the moment are MI300 series + self.use_aiter_and_is_supported = (current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + self.block_quant = self.quant_config.weight_block_size is not None if self.block_quant: # Marlin doesn't support block-wise fp8 @@ -403,7 +412,8 @@ def apply(self, if self.block_quant: assert self.quant_config.weight_block_size is not None - return torch.ops.vllm.apply_w8a8_block_fp8_linear( + # return torch.ops.vllm.apply_w8a8_block_fp8_linear( 
+ return apply_w8a8_block_fp8_linear( input=x, weight=layer.weight, block_size=self.quant_config.weight_block_size, @@ -411,6 +421,7 @@ def apply(self, input_scale=layer.input_scale, bias=bias, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) return self.fp8_linear.apply(input=x, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 5220c97b363b..6051bb87bcdf 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -10,7 +10,6 @@ import triton import triton.language as tl -import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -93,17 +92,12 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( ) -def is_rocm_aiter_gemm_w8a8_blockscale_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE - - def dispatch_w8a8_blockscale_func( - use_cutlass: bool) -> Callable[..., torch.Tensor]: + use_cutlass: bool, + use_aiter_and_is_supported: bool) -> Callable[..., torch.Tensor]: if use_cutlass: return cutlass_scaled_mm - if is_rocm_aiter_gemm_w8a8_blockscale_enabled(): + if (use_aiter_and_is_supported): return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return w8a8_block_fp8_matmul @@ -118,6 +112,7 @@ def apply_w8a8_block_fp8_linear( input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods @@ -132,36 +127,35 @@ def apply_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) - output = dispatch_w8a8_blockscale_func(use_cutlass)( - A=q_input, - B=weight, - As=x_scale, - Bs=weight_scale, - block_size=block_size, - output_dtype=input.dtype) + output = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported)(A=q_input, + B=weight, + As=x_scale, + Bs=weight_scale, + block_size=block_size, + output_dtype=input.dtype) if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) -def apply_w8a8_block_fp8_linear_fake( - input: torch.Tensor, - weight: torch.Tensor, - block_size: List[int], - weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - output_shape = [*input.shape[:-1], weight.shape[0]] - return torch.empty(output_shape, dtype=input.dtype, device=input.device) - - -direct_register_custom_op( - op_name="apply_w8a8_block_fp8_linear", - op_func=apply_w8a8_block_fp8_linear, - mutates_args=[], - fake_impl=apply_w8a8_block_fp8_linear_fake, -) +# def apply_w8a8_block_fp8_linear_fake( +# input: torch.Tensor, +# weight: torch.Tensor, +# block_size: List[int], +# weight_scale: torch.Tensor, +# input_scale: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# output_shape = [*input.shape[:-1], weight.shape[0]] +# return torch.empty(output_shape, dtype=input.dtype, device=input.device) + +# direct_register_custom_op( +# op_name="apply_w8a8_block_fp8_linear", +# op_func=apply_w8a8_block_fp8_linear, +# mutates_args=[], +# fake_impl=apply_w8a8_block_fp8_linear_fake, +# ) def input_to_float8( From 
e06cc3b48b9f37df1ece329c7cfde50c3e478f31 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 22 Apr 2025 13:13:17 +0000 Subject: [PATCH 13/18] restore dispatcher abstraction andupdate unittest Signed-off-by: tjtanaa --- .../model_executor/test_enabled_custom_ops.py | 10 ++++-- .../model_executor/layers/quantization/fp8.py | 5 +-- .../layers/quantization/utils/fp8_utils.py | 36 ++++++++++--------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index bd4068b443dd..4a6791498e61 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -99,6 +99,9 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() +@pytest.mark.skipif( + not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), + reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") @pytest.mark.parametrize("use_cutlass", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) @@ -107,10 +110,13 @@ def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - monkeypatch.setenv("VLLM_ROCM_USE_AITER_GEMM_W8A8_BLOCKSCALE", + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", use_rocm_aiter_gemm_w8a8_blockscale) - block_scale_func = dispatch_w8a8_blockscale_func(use_cutlass) + use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( + int(use_rocm_aiter_gemm_w8a8_blockscale))) + block_scale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) if use_cutlass: assert block_scale_func == cutlass_scaled_mm elif current_platform.is_rocm() and int(use_rocm_aiter) and int( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index bea51e5c31d1..6054bd924fa7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -19,8 +19,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_w8a8_block_fp8_linear) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -412,8 +410,7 @@ def apply(self, if self.block_quant: assert self.quant_config.weight_block_size is not None - # return torch.ops.vllm.apply_w8a8_block_fp8_linear( - return apply_w8a8_block_fp8_linear( + return torch.ops.vllm.apply_w8a8_block_fp8_linear( input=x, weight=layer.weight, block_size=self.quant_config.weight_block_size, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 6051bb87bcdf..1ab18c75ca80 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -140,22 +140,26 @@ def apply_w8a8_block_fp8_linear( return output.to(dtype=input.dtype).view(*output_shape) -# def apply_w8a8_block_fp8_linear_fake( -# input: torch.Tensor, -# weight: torch.Tensor, -# block_size: List[int], -# weight_scale: torch.Tensor, -# 
input_scale: Optional[torch.Tensor] = None, -# ) -> torch.Tensor: -# output_shape = [*input.shape[:-1], weight.shape[0]] -# return torch.empty(output_shape, dtype=input.dtype, device=input.device) - -# direct_register_custom_op( -# op_name="apply_w8a8_block_fp8_linear", -# op_func=apply_w8a8_block_fp8_linear, -# mutates_args=[], -# fake_impl=apply_w8a8_block_fp8_linear_fake, -# ) +def apply_w8a8_block_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, +) -> torch.Tensor: + output_shape = [*input.shape[:-1], weight.shape[0]] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_w8a8_block_fp8_linear", + op_func=apply_w8a8_block_fp8_linear, + mutates_args=[], + fake_impl=apply_w8a8_block_fp8_linear_fake, +) def input_to_float8( From 41bf618e3f3b30a6abe68bef1e394a8ebd4d4400 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 23 Apr 2025 01:56:40 +0000 Subject: [PATCH 14/18] clean up TODO; it is resolved in main Signed-off-by: tjtanaa --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 1ab18c75ca80..05ded072683c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -119,8 +119,6 @@ def apply_w8a8_block_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - # TODO is_shape_supported_by_cutlass is never used, - # as cutlass_block_fp8_supported is False use_cutlass = cutlass_block_fp8_supported and is_shape_supported_by_cutlass( weight, block_size, weight_scale, input_2d) From fc97e94f2ff96f85bdedc64a8b0b3890da94cead Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sun, 11 May 2025 14:41:13 +0000 Subject: [PATCH 15/18] fix pre-commit Signed-off-by: vllmellm --- .../layers/quantization/utils/fp8_utils.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index e1724cfac8ba..16f0b7a33763 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -44,9 +44,13 @@ def is_shape_supported_by_cutlass(weight: torch.Tensor, block_size: List[int], return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0 -def cutlass_scaled_mm(A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, - Bs: torch.Tensor, output_dtype: torch.dtype, - **kwargs) -> torch.Tensor: +def cutlass_scaled_mm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: return ops.cutlass_scaled_mm(A, B.T, out_dtype=output_dtype, @@ -143,13 +147,12 @@ def apply_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(use_cutlass, use_aiter_and_is_supported)( - A=q_input, - B=weight, - As=x_scale, - Bs=weight_scale, - block_size=block_size, - output_dtype=input.dtype) + output = 
w8a8_blockscale_func(A=q_input, + B=weight, + As=x_scale, + Bs=weight_scale, + block_size=block_size, + output_dtype=input.dtype) if should_pad: output = output[:rows, :] @@ -157,13 +160,12 @@ def apply_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(use_cutlass, use_aiter_and_is_supported)( - A=q_input, - B=weight, - As=x_scale, - Bs=weight_scale, - block_size=block_size, - output_dtype=input.dtype) + output = w8a8_blockscale_func(A=q_input, + B=weight, + As=x_scale, + Bs=weight_scale, + block_size=block_size, + output_dtype=input.dtype) if bias is not None: output = output + bias From 7c5491fef2038772618bc667731699d984b89d52 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sun, 11 May 2025 14:52:29 +0000 Subject: [PATCH 16/18] fix pre-commit Signed-off-by: vllmellm --- .../layers/quantization/utils/fp8_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 16f0b7a33763..2a0a4c30fbde 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -49,7 +49,8 @@ def cutlass_scaled_mm( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, - output_dtype: torch.dtype, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: return ops.cutlass_scaled_mm(A, B.T, @@ -63,8 +64,8 @@ def rocm_aiter_gemm_w8a8_blockscale_impl( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, + block_size: List[int], output_dtype: torch.dtype = torch.float16, - block_size: Optional[List[int]] = None, ) -> torch.Tensor: import aiter as rocm_aiter @@ -76,8 +77,8 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, + block_size: List[int], output_dtype: torch.dtype = torch.float16, - block_size: Optional[List[int]] = None, ) -> torch.Tensor: m = A.shape[0] @@ -99,8 +100,12 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( def dispatch_w8a8_blockscale_func( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch. 
- dtype, Optional[List[int]] + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + List[int], + torch.dtype, ], torch.Tensor]: if use_cutlass: return cutlass_scaled_mm From a81cd66ec28f1477614a8eefed89a56c06037fbe Mon Sep 17 00:00:00 2001 From: vllmellm Date: Sun, 11 May 2025 14:59:56 +0000 Subject: [PATCH 17/18] fix pre-commit Signed-off-by: vllmellm --- .../layers/quantization/utils/fp8_utils.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 2a0a4c30fbde..306bae384560 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -152,12 +152,8 @@ def apply_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(A=q_input, - B=weight, - As=x_scale, - Bs=weight_scale, - block_size=block_size, - output_dtype=input.dtype) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) if should_pad: output = output[:rows, :] @@ -165,12 +161,8 @@ def apply_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) - output = w8a8_blockscale_func(A=q_input, - B=weight, - As=x_scale, - Bs=weight_scale, - block_size=block_size, - output_dtype=input.dtype) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) if bias is not None: output = output + bias From 59437291a0f3384d959ded0e602840b2c112ef5e Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 13 May 2025 02:25:27 +0000 Subject: [PATCH 18/18] clean up use_cutlass logic Signed-off-by: vllmellm --- .../layers/quantization/utils/fp8_utils.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 306bae384560..48b2a8d64d36 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -27,23 +27,6 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz -def is_shape_supported_by_cutlass(weight: torch.Tensor, block_size: List[int], - weight_scale: torch.Tensor, - input_2d: torch.Tensor) -> bool: - if current_platform.is_rocm(): - # TODO this is never used, as cutlass_block_fp8_supported is False - scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + - input_2d.shape[:-1])[::-1] - scale_b_shape = (weight_scale.view(-1, 1) - if weight_scale.dim() <= 1 else weight_scale.T).shape - ar, ac = scale_a_shape - br, bc = scale_b_shape - return ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) \ - or br not in (1, weight.shape[0]) - - return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0 - - def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -131,8 +114,11 @@ def apply_w8a8_block_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - use_cutlass = cutlass_block_fp8_supported and is_shape_supported_by_cutlass( - weight, block_size, weight_scale, input_2d) + if current_platform.is_cuda(): + use_cutlass = cutlass_block_fp8_supported and ( + weight.shape[0] % 128 == 0 and weight.shape[1] % 
128 == 0) + else: + use_cutlass = False w8a8_blockscale_func = dispatch_w8a8_blockscale_func( use_cutlass, use_aiter_and_is_supported)
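
Illustration (not part of the patch series): after PATCH 16-18 every dispatched kernel — the CUTLASS path, the AITER ROCm custom op, and the Triton fallback — is expected to share one positional signature, (A, B, As, Bs, block_size, output_dtype), so the dispatcher can return any of them behind a single Callable type. The sketch below mirrors only that selection contract with stand-in functions; the fake_* names are invented for this example, the scale tensors are ignored by the placeholders, and nothing here reflects the real kernels' numerics.

# Minimal sketch of the dispatch contract, assuming only PyTorch is installed.
# The fake_* kernels are placeholders; only the selection order (CUTLASS, then
# AITER, then the Triton-style fallback) and the shared positional signature
# mirror dispatch_w8a8_blockscale_func as it stands after PATCH 18.
from typing import Callable

import torch


def fake_cutlass_scaled_mm(A, B, As, Bs, block_size, output_dtype):
    # Placeholder for the CUTLASS block-scaled GEMM; scales are ignored here.
    return (A.float() @ B.float().T).to(output_dtype)


def fake_aiter_gemm_w8a8_blockscale(A, B, As, Bs, block_size, output_dtype):
    # Placeholder for the registered torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale op.
    return (A.float() @ B.float().T).to(output_dtype)


def fake_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype):
    # Placeholder for the Triton fallback kernel.
    return (A.float() @ B.float().T).to(output_dtype)


def dispatch_sketch(use_cutlass: bool,
                    use_aiter_and_is_supported: bool) -> Callable[..., torch.Tensor]:
    # CUTLASS wins when the CUDA shape check passed, AITER wins when the ROCm
    # env switches resolved to True upstream, otherwise fall back to Triton.
    if use_cutlass:
        return fake_cutlass_scaled_mm
    if use_aiter_and_is_supported:
        return fake_aiter_gemm_w8a8_blockscale
    return fake_w8a8_block_fp8_matmul


if __name__ == "__main__":
    A = torch.randn(4, 128)          # quantized activations would go here
    B = torch.randn(64, 128)         # quantized weight, row-major (N, K)
    As = torch.ones(4, 1)            # per-token-group activation scales
    Bs = torch.ones(1, 1)            # per-block weight scales
    func = dispatch_sketch(use_cutlass=False, use_aiter_and_is_supported=False)
    out = func(A, B, As, Bs, [128, 128], torch.float16)
    print(out.shape, out.dtype)      # torch.Size([4, 64]) torch.float16

Two design points from the series are worth noting, with the usual caveat that this is the editor's reading rather than the authors' stated rationale: PATCH 17 switches the call site to positional arguments, which appears to be what lets all three kernels satisfy one Callable[[Tensor, Tensor, Tensor, Tensor, List[int], dtype], Tensor] annotation without per-kernel keyword handling; and PATCH 13 restores direct_register_custom_op for apply_w8a8_block_fp8_linear with a fake_impl that only reports the output shape and dtype via torch.empty, which is what allows the op to be traced with meta tensors under torch.compile.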