
Conversation

vllmellm (Contributor) commented on May 28, 2025:

This PR introduces a new AITER kernel, grouped_topk, from the aiter package. It is used for model architectures and configurations such as DeepSeek-V2. The previously integrated kernel, biased_grouped_topk, from the aiter package (#17955), is used for architectures and configurations such as DeepSeek-V3.
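A minimal sketch of how the two kernels are selected (inferred from the review hunks later in this thread; the helper name is hypothetical):

```python
from typing import Optional

import torch


def select_rocm_aiter_topk_kernel(
    scoring_func: str,
    e_score_correction_bias: Optional[torch.Tensor],
) -> str:
    # DeepSeek-V3-style configs supply an expert-score correction bias,
    # which routes them to the previously integrated biased kernel;
    # DeepSeek-V2-style configs do not, which routes them to the new
    # grouped_topk kernel added by this PR.
    if e_score_correction_bias is not None:
        return "rocm_aiter_biased_grouped_topk"
    assert scoring_func in ("softmax", "sigmoid")
    return "rocm_aiter_grouped_topk"
```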

lm_eval results on DeepSeek-V2-Lite-Chat

command:
`VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_LINEAR=0 VLLM_ROCM_USE_AITER_RMSNORM=0 SAFETENSORS_FAST_GPU=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=1,max_model_len=32768,block_size=1 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto`

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.6619 | ± 0.0130 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.6528 | ± 0.0131 |

lm_eval results on DeepSeek-V3

command:
`VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_LINEAR=0 VLLM_ROCM_USE_AITER_RMSNORM=0 SAFETENSORS_FAST_GPU=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=32768,block_size=1 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto`

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr  |
|-------|---------|------------------|--------|-------------|-------|---------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.95  | ± 0.006 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.95  | ± 0.006 |

Performance results on DeepSeek-V2-Lite-Chat

This benchmark exercises only the newly added grouped_topk kernel.

commands:

serve: `SAFETENSORS_FAST_GPU=1 VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_LINEAR=0 VLLM_ROCM_USE_AITER_RMSNORM=0 vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat -tp 1 --max-model-len 32768 --block-size 1 --max_seq_len_to_capture 32768 --max-num-batched-tokens 32768`

benchmark serving: `python benchmarks/benchmark_serving.py --model deepseek-ai/DeepSeek-V2-Lite-Chat --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 0.9 --num-prompts 500`
| Metric | Triton grouped_topk | AITER asm grouped_topk |
|---|---|---|
| Successful requests | 500 | 500 |
| Benchmark duration (s) | 67.45 | 66.16 |
| Total input tokens | 493,277 | 493,277 |
| Total generated tokens | 475,127 | 476,229 |
| Request throughput (req/s) | 7.41 | 7.56 |
| Request goodput (req/s) | 0.92 | 0.36 |
| Output token throughput (tok/s) | 7,044.09 | 7,197.72 |
| Total token throughput (tok/s) | 14,357.26 | 14,653.11 |
| **Time to First Token (TTFT)** | | |
| Mean TTFT (ms) | 5,878.18 | 7,194.89 |
| Median TTFT (ms) | 6,061.95 | 7,397.77 |
| P99 TTFT (ms) | 10,688.99 | 11,999.16 |
| **Time per Output Token (excl. 1st token)** | | |
| Mean TPOT (ms) | 51.87 | 52.41 |
| Median TPOT (ms) | 40.99 | 39.02 |
| P99 TPOT (ms) | 272.31 | 635.69 |
| **Inter-token Latency (ITL)** | | |
| Mean ITL (ms) | 39.19 | 37.35 |
| Median ITL (ms) | 34.61 | 33.54 |
| P99 ITL (ms) | 94.38 | 62.50 |

The observed increase in TTFT could hurt the overall user experience.
Would it be better to gate this kernel behind an enable/disable switch?
Any suggestions on this trade-off would be appreciated.
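If a switch is preferred, one option would be to follow the existing VLLM_ROCM_USE_AITER_* pattern used in the commands above (a sketch; the flag name VLLM_ROCM_USE_AITER_GROUPED_TOPK is hypothetical):

```python
import os


def use_aiter_grouped_topk() -> bool:
    # Hypothetical per-kernel opt-out, mirroring the existing
    # VLLM_ROCM_USE_AITER_LINEAR / VLLM_ROCM_USE_AITER_RMSNORM flags:
    # the AITER kernel runs only when the master switch and the
    # per-kernel switch are both enabled.
    master = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    per_kernel = os.environ.get("VLLM_ROCM_USE_AITER_GROUPED_TOPK", "1") == "1"
    return master and per_kernel
```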


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.

🚀

gshtras added the `rocm` (Related to AMD ROCm) label on May 28, 2025.
SageMoore (Contributor) left a comment:

Just one question. Otherwise LGTM

```python
)

if e_score_correction_bias is not None:
    torch.ops.vllm.rocm_aiter_biased_grouped_topk(
```
SageMoore (Contributor) commented:
Do you need to assert that `scoring_func` is `"sigmoid"` here? I'm only asking because the previous assert was deleted.

vllmellm (Contributor, Author) replied on May 30, 2025:

@SageMoore `scoring_func` is not an argument used by `biased_grouped_topk` in this version, so it is only used for `grouped_topk`, which supports both sigmoid and softmax.
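As a point of reference, a naive PyTorch sketch of what grouped top-k with either scoring function computes (this is not the AITER kernel; the group-masking details are assumptions modeled on vLLM's grouped_topk):

```python
import torch


def grouped_topk_ref(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    renormalize: bool,
    num_expert_group: int,
    topk_group: int,
    scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor]:
    # grouped_topk supports both scoring functions; the biased variant
    # implicitly assumes sigmoid scoring.
    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"unsupported scoring_func: {scoring_func}")

    num_tokens, num_experts = scores.shape
    # Rank expert groups by their best expert, keep the top `topk_group`
    # groups, and mask out the experts of every other group.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    score_mask = group_mask.repeat_interleave(
        num_experts // num_expert_group, dim=1)
    masked_scores = scores.masked_fill(score_mask == 0, float("-inf"))

    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```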

```diff
     num_expert_group: int = 0,
     topk_group: int = 0,
-    scoring_func: str = "sigmoid",
+    scoring_func: str = "softmax",
```
houseroad (Collaborator) commented:
This seems like a BC-breaking change; is it intentional?

vllmellm (Contributor, Author) replied:

@houseroad This is intentional. It has been reworked into a new function, `rocm_aiter_grouped_topk`, that follows the same argument signature as `grouped_topk` from …. An extra feature of this function is that it covers both AITER kernels: grouped_topk and biased_grouped_topk.
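Putting the review hunks together, the unified wrapper's signature presumably looks roughly like this (a sketch; only `num_expert_group`, `topk_group`, and the `scoring_func` default are confirmed by the hunks, the other parameters are assumptions):

```python
from typing import Optional

import torch


def rocm_aiter_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",  # default changed from "sigmoid"
    e_score_correction_bias: Optional[torch.Tensor] = None,
):
    # Covers both AITER kernels: a non-None bias selects the biased
    # variant, otherwise the plain grouped_topk kernel runs with the
    # requested scoring function (see the hunks in this thread).
    ...
```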

```python
    )
else:
    assert (scoring_func == "softmax" or scoring_func == "sigmoid")
    torch.ops.vllm.rocm_aiter_grouped_topk(
```
houseroad (Collaborator) commented:
Where are these ops registered? In the AITER repo or the vLLM repo?

vllmellm (Contributor, Author) replied:

@houseroad Both ops, `rocm_aiter_grouped_topk` and `rocm_aiter_biased_grouped_topk`, are registered in vLLM, in this file: rocm_aiter_fused_moe.py
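For illustration, this is roughly how a Python function ends up callable as `torch.ops.vllm.<name>` via PyTorch's public torch.library API (a generic sketch with a hypothetical `_demo` op; the actual registration in rocm_aiter_fused_moe.py uses vLLM's own helper and wraps the real AITER kernels):

```python
import torch
from torch.library import custom_op


# Hypothetical demo op; the real ops wrap the AITER kernels and are
# registered in rocm_aiter_fused_moe.py.
@custom_op("vllm::rocm_aiter_grouped_topk_demo", mutates_args=())
def rocm_aiter_grouped_topk_demo(gating_output: torch.Tensor,
                                 topk: int) -> torch.Tensor:
    # Stand-in body: a plain top-k over the gating scores.
    return torch.topk(gating_output, k=topk, dim=-1).indices


# Once registered, the op is reachable through the dispatcher:
# torch.ops.vllm.rocm_aiter_grouped_topk_demo(scores, 2)
```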

DarkLight1337 added the `ready` (ONLY add when PR is ready to merge/full CI is needed) label on May 30, 2025.
DarkLight1337 (Member) left a comment:
The test failures are not related to this PR

vllm-bot merged commit 0f5e0d5 into vllm-project:main on May 31, 2025; 77 of 80 checks passed.
amitm02 pushed a commit to amitm02/vllm that referenced this pull request on Jun 1, 2025.