
Conversation

@yewentao256 (Member) commented on Jul 28, 2025

Purpose

Similar to #19452, we temporarily fix this by enforcing contiguous input. We will work on #19630 later to see whether some non-contiguous input cases can be supported.

lm_eval --model vllm --model_args "pretrained=nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8,max_model_len=32768,enable_expert_parallel=True,enforce_eager=True" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

[rank0]:   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/model_executor/layers/quantization/input_quant_fp8.py", line 60, in forward_cuda
[rank0]:     return ops.scaled_fp8_quant(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/_custom_ops.py", line 1288, in scaled_fp8_quant
[rank0]:     torch.ops._C.static_scaled_fp8_quant(output, input, scale)
[rank0]:   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Expected input.is_contiguous() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
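For illustration, a minimal sketch of the fix pattern, assuming the wrapper shape visible in the traceback above (this mirrors `scaled_fp8_quant` in `vllm/_custom_ops.py` but is not the literal diff):

```python
import torch

def scaled_fp8_quant_sketch(input: torch.Tensor,
                            scale: torch.Tensor) -> torch.Tensor:
    # The static/dynamic scaled FP8 quant kernels assert
    # input.is_contiguous(), so force a contiguous copy up front;
    # .contiguous() is a no-op for already-contiguous tensors, so the
    # common path pays nothing extra.
    input = input.contiguous()
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output
```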

Test

llm (pretrained=nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8,max_model_len=32768,enable_expert_parallel=True,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7415|±  |0.0121|
|     |       |strict-match    |     5|exact_match|↑  |0.7104|±  |0.0125|

Signed-off-by: yewentao256 <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses a RuntimeError that occurs when non-contiguous tensors are passed to the dynamic_scaled_fp8_quant and static_scaled_fp8_quant custom operators. The fix correctly enforces input contiguity by calling .contiguous() on the input tensor before it is passed to the underlying C++ kernels.

The change is a direct and effective solution to the bug described, and it aligns with the existing pattern for other custom operator calls within the same function. The pull request description clearly communicates that this is an intentional, temporary measure to ensure stability, with a more performant, long-term solution tracked in a separate issue. The code is clean, localized, and I found no issues of high or critical severity. The pull request is ready to be merged.

@mgoin added the bug ("Something isn't working") and ready ("ONLY add when PR is ready to merge/full CI is needed") labels on Jul 28, 2025
@mgoin enabled auto-merge (squash) on Jul 28, 2025 at 17:51
@mgoin (Member) left a comment


LGTM. I'm curious what the strides actually are. I think the kernels should be able to support non-contiguous inputs as-is if input.stride(-1) == 1
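For reference, a hedged sketch of the case this describes (standalone PyTorch; the helper name is hypothetical, and teaching the kernels to accept such inputs is what #19630 tracks):

```python
import torch

# A column slice is non-contiguous, yet its innermost stride is 1,
# so a row-wise kernel could still read each row as a dense segment.
x = torch.randn(4, 8)[:, :6]   # shape (4, 6), stride (8, 1)
print(x.is_contiguous())       # False -> current kernels reject it
print(x.stride(-1) == 1)       # True  -> the relaxed check would pass

# Hypothetical relaxed precondition: rows are dense and reachable
# via stride(0), so no full copy would be needed.
def rows_are_dense(t: torch.Tensor) -> bool:
    return t.stride(-1) == 1
```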

@yewentao256 (Member, Author) commented
lm_eval --model vllm --model_args "pretrained=nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8,max_model_len=32768,enable_expert_parallel=True,enforce_eager=True" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

  • shape: torch.Size([975, 512])
  • stride: (576, 1)

Yeah, I will look at #19630 later to see if we can support this case using a scalar fallback.
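To illustrate where that stride pattern can come from (a hypothetical reconstruction; the actual producer in the expert-parallel path is not identified in this thread):

```python
import torch

# A (975, 512) view with stride (576, 1) is exactly what a column
# slice of a wider (e.g. padded) 576-column buffer looks like.
buf = torch.empty(975, 576)
x = buf[:, :512]
assert x.shape == (975, 512)
assert x.stride() == (576, 1)
assert not x.is_contiguous()   # hence the .contiguous() workaround
```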

@mgoin merged commit e0e58f9 into vllm-project:main on Jul 28, 2025
74 checks passed
@yewentao256 deleted the wye-enforce-contiguous-for-dynamic/static-fp8-quant branch on Jul 29, 2025
