
Conversation

@divakar-amd (Contributor) commented Aug 20, 2025

Purpose

Replace torch.bmm with aiter batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant kernel for MLA
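
For readers unfamiliar with the kernel name, the sketch below emulates in plain PyTorch what the fused aiter kernel computes: activations are quantized per token group at runtime, weights are pre-quantized with one scale per tensor in the batch, and a batched GEMM is applied to the quantized operands. This is an illustrative emulation only; the group size, shapes, and use of int8 as a stand-in for FP8 are assumptions, and none of this is the PR's or aiter's actual code.

import torch

def fp8_bmm_emulated(x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor,
                     group: int = 128) -> torch.Tensor:
    """Emulate an a8w8 batched GEMM: per-token-group activation quantization plus
    per-batched-tensor weight quantization (int8 stands in for FP8 so this runs anywhere)."""
    b, m, k = x.shape
    xg = x.view(b, m, k // group, group)
    x_scale = (xg.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-8)
    x_q = (xg / x_scale).round().clamp(-127, 127)
    x_deq = (x_q * x_scale).view(b, m, k)   # dequantized activations
    w_deq = w_q.float() * w_scale           # dequantized weights
    return torch.bmm(x_deq, w_deq)          # the math the fused kernel performs

# Weights are quantized once at load time, one scale per tensor in the batch.
w = torch.randn(8, 512, 128)
w_scale = (w.abs().amax(dim=(1, 2), keepdim=True) / 127.0).clamp_min(1e-8)
w_q = (w / w_scale).round().clamp(-127, 127).to(torch.int8)

x = torch.randn(8, 4, 512)
out_ref = torch.bmm(x, w)                 # the reference path this PR replaces
out_q = fp8_bmm_emulated(x, w_q, w_scale)
print((out_ref - out_q).abs().max())      # small quantization error expected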

Test Plan

  • Correctness check
  • Perf check

Test Result

Correctness Result


Generated Outputs:
------------------------------------------------------------
Prompt:    'Hello, my name is'
Output:    ' Christian Munoz and\nthis is my blog where I cover different\nIT topics'
------------------------------------------------------------
Prompt:    'The president of the United States is'
Output:    ' in charge of which branch of government? A. judicial B. legislative C.'
------------------------------------------------------------
Prompt:    'The capital of France is'
Output:    ' Paris. Paris is located along the Seine River in the north-central part of the'
------------------------------------------------------------
Prompt:    'The future of AI is'
Output:    " a fascinating and rapidly evolving field. Here's a glimpse into some key areas shaping"
------------------------------------------------------------

Performance test on DeepSeek-R1 with full-cudagraph capture mode

REQUEST_RATES=(1 5 7 9)
TOTAL_SECONDS=20
TP=8
OUTPUT_LEN=128

# MODEL points to the DeepSeek-R1 checkpoint (e.g. deepseek-ai/DeepSeek-R1)
DATASET_PATH="ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json"
PYTHON_BENCH_SCRIPT="benchmarks/benchmark_serving.py"

# Start the server (in its own shell), then run the benchmark sweep below.
VLLM_USE_V1=1 vllm serve $MODEL \
    --port 8004 \
    --tensor-parallel-size $TP \
    --max-num-seqs 256 \
    --no-enable-prefix-caching \
    --swap-space 16 \
    --disable-log-requests \
    --disable-uvicorn-access-log \
    --block-size 1 \
    -O '{"full_cuda_graph":true}'

for REQUEST_RATE in "${REQUEST_RATES[@]}"; do
    python3 $PYTHON_BENCH_SCRIPT \
        --model $MODEL \
        --percentile-metrics ttft,tpot,itl,e2el \
        --dataset-path $DATASET_PATH \
        --request-rate $REQUEST_RATE \
        --num-prompts $(($TOTAL_SECONDS * $REQUEST_RATE)) \
        --ignore-eos \
        --port 8004 \
        --sharegpt-output-len $OUTPUT_LEN
done

Performance with and without this kernel
[Plot: aiter without BMM (run 2) vs aiter BMM (pre-compiled)]

Pre-compilation for the kernel

If the kernel is not pre-compiled, the plot below shows the performance gap between the first run (aiter_BMM_run1) and a subsequent run (aiter_BMM_run2). Adding a pre-compilation step during weight loading resolves this issue.
[Plot: aiter BMM run 1 vs run 2 vs pre-compiled]
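
The warm-up itself is conceptually simple: during weight loading, invoke the kernel once for every batch size expected at runtime so each JIT specialization is compiled before the first request arrives. A generic sketch of that idea (torch.bmm stands in for the aiter kernel; the helper name and shapes are illustrative, not the PR's code):

import torch

def precompile_kernel(kernel, make_dummy_inputs, max_batch_size: int = 1024) -> None:
    # Pay the JIT compilation cost up front, once per batch size, instead of on the
    # first request that happens to hit that size.
    for bs in range(1, max_batch_size + 1):
        kernel(*make_dummy_inputs(bs))

# Stand-in usage: torch.bmm plays the role of the aiter FP8 BMM kernel here.
weights = torch.randn(8, 512, 128)
precompile_kernel(torch.bmm,
                  lambda bs: (torch.randn(8, bs, 512), weights),
                  max_batch_size=4)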

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Aug 20, 2025
@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new Triton kernel for FP8 batched matrix multiplication (bmm) within the MLA backend on ROCm, aimed at improving performance. The changes include adding a new environment variable VLLM_ROCM_USE_AITER_FP8BMM to control this feature, and conditionally using the new kernel in place of torch.bmm for specific operations in the MLA implementation.

My review has identified one critical issue: the new environment variable is not included in the computation graph hash, which can lead to incorrect caching behavior. This must be addressed to ensure correctness.
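
For context, the risk is that compilation artifacts cached with one setting of the flag could be reused when the flag flips. The general fix is to fold every code-generation-affecting flag into the cache key. A generic sketch of that idea follows; the factor list and helper below are illustrative, not vLLM's actual hashing code.

import hashlib
import os

def compilation_cache_key(model_name: str) -> str:
    # Every knob that changes the compiled graph or kernels must contribute to the key;
    # omitting one (e.g. VLLM_ROCM_USE_AITER_FP8BMM) allows stale cache hits.
    factors = [
        model_name,
        os.environ.get("VLLM_ROCM_USE_AITER", ""),
        os.environ.get("VLLM_ROCM_USE_AITER_FP8BMM", ""),
    ]
    return hashlib.sha256("|".join(factors).encode()).hexdigest()

print(compilation_cache_key("deepseek-ai/DeepSeek-R1"))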

# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
Contributor
I'm not crazy about adding this much overhead to the model loading time. CC @mgoin @LucasWilkinson I don't know how much pre-compilation we consider "acceptable". @divakar-amd how much does this improve performance?

@divakar-amd (Contributor, Author) commented Aug 25, 2025

Without pre-compilation, the performance is worse than torch.bmm for the first run. I added a plot above showing the performance difference if pre-compilation is not used. @SageMoore
Also, pre-compilation will add to the model loading time only if AITER is enabled.
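
A schematic of that gating (reading flags as attributes of vllm.envs matches how the existing ROCm AITER flag is exposed; whether the new flag is exposed the same way, and the helper name, are assumptions):

import vllm.envs as envs
from vllm.platforms import current_platform

def should_precompile_fp8_bmm() -> bool:
    # The ~50s warm-up only runs when the AITER FP8 BMM path is actually selected.
    return bool(current_platform.is_rocm()
                and envs.VLLM_ROCM_USE_AITER
                and envs.VLLM_ROCM_USE_AITER_FP8BMM)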

Contributor

I think this is generally fine. As you pointed out offline we do precompilation for other AITER kernels already and that takes less time than torch.compile, which seems like a reasonable upper bound. I do agree that the TTFT improvements are nice.

@ProExpertProg added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Aug 27, 2025
@gshtras enabled auto-merge (squash) on August 28, 2025 16:17
@gshtras merged commit 04d1dd7 into vllm-project:main on Aug 28, 2025
38 of 39 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025