
Conversation

tjtanaa
Contributor

@tjtanaa tjtanaa commented Mar 17, 2025

This PR integrates the RMS Norm layer functionality from AITER (AI Tensor Engine for ROCm) into vLLM.
It introduces the AITER RMS Norm layer kernel so that any upcoming optimization in the AITER kernel can be used and evaluated directly within the vLLM framework.

RMS Norm Layer Implementation

The rmsnorm2d_fwd_with_add kernel from AITER has been integrated for the ROCm RMS norm forward pass in /vllm/model_executor/layers/layernorm.py (a sketch of the gating follows the list below). This feature:

  • Is enabled by default when the environment variable VLLM_ROCM_USE_AITER=1 is set
  • Can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_RMSNORM
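
For orientation, here is a minimal sketch of what such an env-flag-gated dispatch could look like. Only the environment variable names, the rmsnorm2d_fwd_with_add kernel name, and the add-then-normalize semantics come from this PR; the helper names, the flag defaults, and the assumption that the kernel can be imported directly from the aiter package are illustrative, not the actual layernorm.py change.

import os
from typing import Callable, Tuple

import torch


def fused_add_rms_norm_reference(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain-PyTorch reference of the default path: add the residual,
    # then apply RMS norm to the sum.
    residual_out = x + residual
    variance = residual_out.pow(2).mean(dim=-1, keepdim=True)
    out = residual_out * torch.rsqrt(variance + eps) * weight
    return out, residual_out


def _rocm_aiter_rmsnorm_enabled() -> bool:
    # Both switches must allow the AITER path; the default values used
    # here are assumptions made for this sketch.
    return (os.getenv("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "1") == "1")


def dispatch_fused_add_rms_norm(
) -> Callable[..., Tuple[torch.Tensor, torch.Tensor]]:
    if _rocm_aiter_rmsnorm_enabled():
        # Imported lazily so non-ROCm setups never touch the package;
        # assumes the kernel is exposed at the top level of aiter.
        from aiter import rmsnorm2d_fwd_with_add
        return rmsnorm2d_fwd_with_add
    return fused_add_rms_norm_reference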

Performance Improvements over Not Using the AITER Kernel

Llama-3.1-8B-Instruct (with FP8 per-tensor dynamic quantization)

  • RMS norm only: -1.1~0.8% performance change

Llama-3.1-8B-Instruct-BF16

  • RMS norm only: 0.5~3.9% performance improvement

Llama-3.1-70B-Instruct (with FP8 per-tensor dynamic quantization)

  • RMS norm only: -0.02~2% performance change

Llama-3.1-70B-Instruct-BF16

  • RMS norm only: -0.12~1.2% performance change

Testing

The integration has been verified through:

  • High-level integration tests with various models
  • Kernel function dispatch testing to ensure correct operation selection
  • Quantization compatibility testing

This PR is part of a larger effort to integrate AITER kernels into vLLM for improved performance on ROCm platforms.

Unit Test Status

  • tests/model_executor/test_enabled_custom_ops.py [Passed]
  • tests/models/decoder_only/language/test_models.py [Passed*]
      • All passed except bigscience/bloom-560m, which has also been failing on the main branch. (This branch's unit test statuses match those of vllm-project/vllm main.)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of these by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 17, 2025
Comment on lines 82 to 83
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
Member

Where is this environment variable being used?

Contributor Author

@tjtanaa tjtanaa Mar 17, 2025

It is used in .buildkite/run-amd-test.sh to skip the unit tests in the CI environment. I have added the changes to .buildkite/run-amd-test.sh.

Member

@DarkLight1337 DarkLight1337 Mar 17, 2025

Instead of using environment variables, can you use pytest custom markers to select/exclude tests?
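
For context, here is a minimal sketch of the marker-based approach being suggested; the marker name rocm_aiter and the conftest wiring are illustrative assumptions, not part of this PR.

# conftest.py (illustrative)
import pytest


def pytest_configure(config):
    # Register the custom marker so pytest does not warn about unknown marks.
    config.addinivalue_line(
        "markers", "rocm_aiter: tests that exercise ROCm AITER kernels")


# In a test module (illustrative):
@pytest.mark.rocm_aiter
def test_rms_norm_matches_reference():
    ...


# Tests can then be selected or excluded on the command line instead of
# via an environment variable:
#   pytest -m rocm_aiter          # run only AITER-marked tests
#   pytest -m "not rocm_aiter"    # skip AITER-marked tests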

Contributor

Why do we want to skip the test in CI? I generally agree, though. A pytest marker would be nicer than an environment variable.

Contributor Author

@SageMoore Previously we did not want to break the AMD CI, so we had it temporarily disabled. Now that we have pinned AITER to a specific commit in Dockerfile.rocm_base and ensured all of its unit tests pass, we will enable the test in CI as well.

@SageMoore @DarkLight1337 About the pytest marker:
Since we enable the AITER kernel tests by default, we don't need a way to disable AITER in these tests. This also removes the need for a pytest marker or any other form of decorator.

So, is it ok to keep it as follows?

...
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                dtype: str, max_tokens: int, num_logprobs: int,
                use_rocm_aiter: bool, monkeypatch) -> None:

    if model in REQUIRES_V0 or current_platform.is_rocm():
        monkeypatch.setenv("VLLM_USE_V1", "0")

    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with hf_runner(model, dtype=dtype) as hf_model:
        if model.startswith("THUDM/chatglm3"):
            hf_model.model.get_output_embeddings = lambda: \
                hf_model.model.transformer.output_layer

        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

Similarly in #14967 (comment)

Contributor Author

@SageMoore @DarkLight1337 I have also taken into account your comments from the other AITER PRs:

  • We can enable AITER in the AMD CI since AITER is pinned to a specific commit. (This avoids the need for a pytest marker or decorator.)
  • The tests are enabled only for models that actually use AITER kernels, to avoid redundant models.

This is the final state of test_models.py:

# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py`.
"""

import pytest
+ import torch

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

# These have unsupported head_dim for FA. We do not
# have a clean way to fall back, so we fail with
# a clear msg when it happens.
# https://github.com/vllm-project/vllm/issues/14524
REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"]

+ # This list contains the models that use AITER kernels.
+ # Models not in this list are skipped for the AITER test runs.
+ # Once more AITER kernels are added, this list will no longer be
+ # needed, as all models will call AITER kernels in parts of
+ # their operators.
+ AITER_MODEL_LIST = [
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "openbmb/MiniCPM3-4B",
+    "Qwen/Qwen-7B",
+    "Qwen/Qwen2.5-0.5B-Instruct",
+    "ehristoforu/Falcon3-MoE-2x7B-Insruct",
+ ]


# @maybe_test_rocm_aiter
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(
            "bigscience/bloom-560m",  # bloom - testing alibi slopes
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openai-community/gpt2",  # gpt2
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
        pytest.param(
            "google/gemma-1.1-2b-it",  # gemma
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "THUDM/chatglm3-6b",  # chatglm (text-only)
        ),
        pytest.param(
            "meta-llama/Llama-3.2-1B-Instruct",  # llama
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "openbmb/MiniCPM3-4B",
            # fused_moe not supported on CPU
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "facebook/opt-125m",  # opt
            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
        ),
        pytest.param(
            "microsoft/phi-2",  # phi
            marks=[pytest.mark.core_model],
        ),
        pytest.param(
            "Qwen/Qwen-7B",  # qwen (text-only)
        ),
        pytest.param(
            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
            marks=[pytest.mark.core_model],
        ),
        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
        pytest.param(
            "ehristoforu/Falcon3-MoE-2x7B-Insruct",  # mixtral
            marks=[pytest.mark.cpu_model],
        )
    ])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
+ @pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
                dtype: str, max_tokens: int, num_logprobs: int,
+                use_rocm_aiter: bool, monkeypatch) -> None:

    if model in REQUIRES_V0:
        monkeypatch.setenv("VLLM_USE_V1", "0")

+    if use_rocm_aiter and (model in AITER_MODEL_LIST):
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+    elif use_rocm_aiter and model not in AITER_MODEL_LIST:
+        # Skip models that do not use AITER kernels.
+        # Once more AITER kernels are added, this list will no longer be
+        # needed, as all models will call AITER kernels in parts of
+        # their operators.
+        pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

    with hf_runner(model, dtype=dtype) as hf_model:
        if model.startswith("THUDM/chatglm3"):
            hf_model.model.get_output_embeddings = lambda: \
                hf_model.model.transformer.output_layer

        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
+    if use_rocm_aiter:
+        # this is to ensure that vllm engine
+        # has deallocated the memory before running the next
+        # unit tests. On ROCm, the memory might not be
+        # deallocated completely before running the
+        # next test case
+        torch.cuda.synchronize()

Contributor

@SageMoore SageMoore left a comment

I requested a few changes, but otherwise looks reasonable. Thanks for breaking it out of the mono-PR!

def rocm_aiter_rmsnorm2d_fwd_with_add(
        *, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    import aiter as rocm_aiter
Contributor

Given that AITER isn't published on PyPI yet, meaning users will either have to use the Docker container or build from source, I'd like to have a nicer error message when users try to enable AITER without it being installed. There are a number of ways we can do this. I like the following but am open to other solutions.

def dispatch_cuda_rmsnorm_func(
    add_residual: bool
) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    if not add_residual:
        return rms_norm
    if current_platform.is_rocm_aiter_rmsnorm_enabled():
        try:
            import aiter as rocm_aiter
            return rocm_aiter_rmsnorm2d_fwd_with_add
        except ImportError:
            logger.warn_once("AITER RMS Norm kernel is enabled, but AITER is not installed. Falling back to the default RMS Norm kernel")
            return fused_add_rms_norm
    return fused_add_rms_norm

Contributor Author

@SageMoore

  1. import aiter conflicts with a built-in function in Python, and wrapping import aiter in a try/except does not show whether AITER is properly installed, unless we try to import a kernel function from aiter; if that kernel is not prebuilt, it would start building the kernel JIT.
  2. Having a fallback makes it difficult to debug and to pin down performance differences. In addition, a warning alone might be missed, and users could then complain about performance, since they expect AITER kernels to be used when the AITER flag is set.

So, we will avoid having a fallback here.
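
In other words, a minimal sketch of the resulting behavior, adapting the dispatch snippet quoted above under the assumption that the try/except is simply dropped so a missing AITER install surfaces as an error rather than a silent fallback:

def dispatch_cuda_rmsnorm_func(
    add_residual: bool
) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    if not add_residual:
        return rms_norm
    if current_platform.is_rocm_aiter_rmsnorm_enabled():
        # No try/except here: if AITER is missing, the wrapper's own
        # `import aiter` raises when the kernel is first used, instead of
        # being masked by a warning plus a fallback to the default kernel.
        return rocm_aiter_rmsnorm2d_fwd_with_add
    return fused_add_rms_norm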

@tjtanaa tjtanaa requested a review from SageMoore March 20, 2025 17:26
Contributor

@SageMoore SageMoore left a comment

I think we're pretty close here. All of my comments are nits.

name_0="hf",
name_1="vllm",
)
if use_rocm_aiter:
Contributor

Is this something we should generally be doing for ROCm or just when AITER is enabled?

Contributor Author

Is this something we should generally be doing for ROCm or just when AITER is enabled?

Currently, it seems this situation could occur only when AITER is enabled.

Contributor Author

@SageMoore
We have made the description clearer:

        # this is to ensure that vllm engine
        # has deallocated the memory before running the next
+        # unit tests. On ROCm, when using AITER
+        # the memory might not be deallocated completely
+        # before running the next test case
        torch.cuda.synchronize()

Contributor

Good to know. Thanks!

tjtanaa added 2 commits March 21, 2025 01:10
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Contributor

@SageMoore SageMoore left a comment

Thanks for addressing all of my comments. This looks reasonable to me.

name_0="hf",
name_1="vllm",
)
if use_rocm_aiter:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to know. Thanks!

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Stamp

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 21, 2025 14:54
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 21, 2025
@vllm-bot vllm-bot merged commit ec870fb into vllm-project:main Mar 22, 2025
38 of 42 checks passed
@tjtanaa tjtanaa deleted the aiter-rmsnorm branch March 22, 2025 10:46
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed

4 participants