Skip to content

Conversation

yewentao256
Copy link
Member

@yewentao256 yewentao256 commented Jul 30, 2025

Purpose

Fix #19630

  • using common vectorization utils
  • support non-contiguous input (last stride == 1)
  • empty instead of zeros

Test

Acc

lm_eval --model vllm --model_args "pretrained=nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8,max_model_len=32768,enable_expert_parallel=True,enforce_eager=True" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

# Now
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.7574|±  |0.0118|
|     |       |strict-match    |     5|exact_match||0.7354|±  |0.0122|

# main
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.7574|±  |0.0118|
|     |       |strict-match    |     5|exact_match||0.7354|±  |0.0122|

Perf

vllm bench throughput --model nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8 --input-len 1000 --output-len 100 --trust_remote_code --enforce_eager

# Now
Throughput: 49.45 requests/s, 54306.45 total tokens/s, 4944.88 output tokens/s
# main
Throughput: 48.92 requests/s, 53730.95 total tokens/s, 4892.48 output tokens/s

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for non-contiguous tensors in FP8 quantization kernels. However, the implementation of segmented_max_reduction_strided has critical correctness and performance issues that need to be addressed before merging.

@mgoin mgoin self-assigned this Jul 30, 2025
@mgoin mgoin requested review from LucasWilkinson and mgoin July 30, 2025 19:43
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! Can you add at least one case to the unit test that expresses this type of non-contiguous input?

Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
@mgoin
Copy link
Member

mgoin commented Aug 1, 2025

@yewentao256 can you add a simple unit test for this before merge?

@yewentao256
Copy link
Member Author

@yewentao256 can you add a simple unit test for this before merge?

Thanks for the comment! Done

(.wentao_env) wentao@gpu66:~/vllm-source/tests/quantization$ pytest test_fp8.py::test_scaled_fp8_quant
================================ test session starts ================================
platform linux -- Python 3.12.3, pytest-8.4.0, pluggy-1.6.0
rootdir: /home/wentao/vllm-source
configfile: pyproject.toml
plugins: asyncio-1.0.0, anyio-4.9.0
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 2 items                                                                   

test_fp8.py ..                                                                [100%]

================================= 2 passed in 2.12s =================================

@mgoin mgoin added performance Performance-related issues deepseek Related to DeepSeek models labels Aug 1, 2025
@mgoin mgoin enabled auto-merge (squash) August 1, 2025 21:19
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 1, 2025
@vllm-bot vllm-bot merged commit 4771df7 into vllm-project:main Aug 5, 2025
72 of 74 checks passed
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
myselvess pushed a commit to myselvess/vllm that referenced this pull request Aug 7, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Perf]: Support non-contiguous input for dynamic_scaled_int8_quant and dynamic_per_token_scaled_fp8_quant

4 participants