
Conversation

@fhl2000 (Contributor) commented Jul 22, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR is split from the original #20059 to support full cudagraph for FlashInfer (pure decode only): pure-decode batches run under full cudagraph, while mixed prefill-decode batches fall back to no cudagraph. The hope is to land this before #20059.

Details include:

  • Using the persistent buffer trick.
  • Creating one decode wrapper per cudagraph batch size, as required by the FlashInfer API (see the sketch below).
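A minimal sketch of these two ideas, using hypothetical names and buffer sizes (the real implementation lives in vllm/v1/attention/backends/flashinfer.py and differs in detail):

import torch
import flashinfer

class FlashInferCudagraphState:
    """Sketch only: persistent buffers plus one decode wrapper per capture size."""

    def __init__(self, max_num_reqs, max_num_pages, device, cudagraph_batch_sizes):
        # Persistent buffers: captured graphs replay against these fixed
        # addresses, so per-step values are written into them in place.
        self.paged_kv_indptr = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=device)
        self.paged_kv_indices = torch.zeros(
            max_num_pages, dtype=torch.int32, device=device)
        self.paged_kv_last_page_len = torch.zeros(
            max_num_reqs, dtype=torch.int32, device=device)
        workspace = torch.zeros(
            128 * 1024 * 1024, dtype=torch.uint8, device=device)

        # One decode wrapper per cudagraph batch size: FlashInfer's CUDA-graph
        # mode requires use_cuda_graph=True and fixed-size index buffers, so
        # each captured size gets its own wrapper backed by the same buffers.
        self._decode_wrappers_cudagraph = {
            bs: flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                workspace,
                use_cuda_graph=True,
                paged_kv_indptr_buffer=self.paged_kv_indptr[:bs + 1],
                paged_kv_indices_buffer=self.paged_kv_indices,
                paged_kv_last_page_len_buffer=self.paged_kv_last_page_len[:bs],
            )
            for bs in cudagraph_batch_sizes
        }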

This PR also fixes a potential bug originating from #18581, where an assertion error is raised during capture if max_capture_size is greater than max_num_reqs. To resolve this, a new enum type AttentionCGSupport (adapted from #20059) is introduced to distinguish how a backend supports cudagraph, so that cudagraph_batch_sizes can be overridden to contain no size greater than max_num_seqs; a sketch follows below.
NOTE: Manually setting the max capture size currently seems impossible after the introduction of Pydantic, which prevents the --cuda_graph_sizes config from being a plain int.
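A minimal sketch of the AttentionCGSupport idea and the size filtering, with simplified member names (the enum adapted from #20059 may differ in detail):

import enum

class AttentionCGSupport(enum.Enum):
    NEVER = 0              # backend cannot run under full cudagraph
    PURE_DECODE_ONLY = 1   # full cudagraph only for pure-decode batches
    ALWAYS = 2             # full cudagraph for any batch composition

def filter_capture_sizes(cudagraph_batch_sizes, max_num_seqs, cg_support):
    # A pure-decode batch never contains more tokens than there are
    # sequences, so capture sizes above max_num_seqs can never be hit.
    if cg_support == AttentionCGSupport.PURE_DECODE_ONLY:
        return [s for s in cudagraph_batch_sizes if s <= max_num_seqs]
    return list(cudagraph_batch_sizes)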

Limitation:

  1. The FlashInfer backend currently does not support spec-decode when full cudagraph (pure-decode only) is enabled.
  2. FlashInfer's trtllm_batch_decode_with_kv_cache is not yet supported in this PR.

(Updated) After #21137 was merged

To resolve the conflicts with the earlier version of this PR and further reduce potential overhead, the new commits add both host-side and device-side persistent buffers and override FlashInfer's decode plan function for cudagraph execution. For a full-cudagraph common decode, we now avoid copying from temporary tensors into the persistent buffers by writing results directly into the persistent buffers; we further avoid the device-to-device copy of paged_kv_indices, retaining only two host-to-device copies (paged_kv_indptr and paged_kv_last_page_len) by overriding that plan function.
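A rough sketch of this reduced copy path, with hypothetical names (the real override wraps FlashInfer's decode plan and is more involved):

import torch

def fast_decode_plan_sketch(state, host_indptr, host_last_page_len, num_decodes):
    # paged_kv_indices is produced by the metadata builder directly inside the
    # device-side persistent buffer, so no device-to-device copy is needed.
    #
    # Only two host-to-device copies remain per step, written into the fixed
    # buffers that the cudagraphs were captured against.
    state.paged_kv_indptr[:num_decodes + 1].copy_(
        host_indptr[:num_decodes + 1], non_blocking=True)
    state.paged_kv_last_page_len[:num_decodes].copy_(
        host_last_page_len[:num_decodes], non_blocking=True)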

Test Plan

lm_eval, benchmark_serving

Test Result

See comments below.

(Optional) Documentation Update

fhl2000 and others added 2 commits July 22, 2025 03:51
Signed-off-by: fhl <[email protected]>

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces full CUDA graph support for the FlashInfer attention backend, which is a great performance enhancement for pure decode scenarios. The use of a new AttentionCGSupport enum to manage different levels of CUDA graph support across backends is a solid design choice that improves code clarity and maintainability.

The PR also includes an important bug fix to prevent graph capture for unsupported batch sizes, which is crucial for stability. I've identified one critical issue where a data structure for padding batch sizes is not updated after filtering the capture sizes, which could lead to runtime errors. I've provided a suggestion to fix this. Overall, this is a valuable contribution.

Comment on lines 2377 to 2379
self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs]
Contributor

critical

This change correctly filters self.cudagraph_batch_sizes to prevent capturing graphs for sizes larger than max_num_seqs for PURE_DECODE_ONLY backends. However, the pad_for_cudagraph method, which is used at runtime to determine the padded graph size, relies on a mapping (bs_to_padded_graph_size) that was initialized with the original, unfiltered cudagraph_batch_sizes.

This discrepancy can lead to a KeyError at runtime. For example, if a batch with num_decodes is processed, pad_for_cudagraph might return a padded size that was filtered out and for which no CUDA graph was captured. This will cause a lookup failure in _decode_wrappers_cudagraph.

To fix this, you should re-initialize the padding map after filtering self.cudagraph_batch_sizes.

Suggested change

Before:
self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs]

After:
self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs
]
self.vllm_config.compilation_config.init_with_cudagraph_sizes(
    self.cudagraph_batch_sizes)

@fhl2000 (Contributor Author) Jul 22, 2025

Good catch for pad_for_cudagraph, though I think it would not affect the final correctness.
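For readers following along, an illustrative sketch (hypothetical helper, not the exact vLLM fields) of why a stale batch-size-to-padded-size map matters:

def build_padding_map(cudagraph_batch_sizes, max_size):
    # For every possible batch size, record the smallest captured size >= it.
    sizes = sorted(cudagraph_batch_sizes)
    return {bs: next((s for s in sizes if s >= bs), None)
            for bs in range(1, max_size + 1)}

padding_map = build_padding_map([1, 2, 4, 8], max_size=8)
# If sizes 4 and 8 are later filtered out (e.g. max_num_seqs == 2) but this
# map is not rebuilt, a batch of 3 would still be padded to 4, a size for
# which no graph and no decode wrapper exist.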

Member

I find that the following hangs:

 VLLM_ATTENTION_BACKEND=FLASHINFER  vllm serve models/Llama-3.1-8B-Instruct --no-enable-prefix-caching --compilation-config='{"full_cuda_graph": true}' --max-num-seqs 2

at this point:

Capturing CUDA graph shapes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.42it/s]
INFO 07-22 10:32:54 [gpu_model_runner.py:2404] Graph capturing finished in 1 secs, took 0.42 GiB

It just hangs here - could this be related?

Member

If I remove --max-num-seqs then it works fine, so I think it is indeed related.

Contributor Author

I find that overriding the GPU model runner's cudagraph_batch_sizes is enough. The vllm_config.compilation_config.init_with_cudagraph_sizes method does not actually override the compilation config's cudagraph_batch_sizes after its first call.

Member

Hmm unfortunately it doesn't happen on main + FlashAttention. The hang is 100% reproducible using the code from this PR.

Contributor Author

I will try my best to figure it out.

@fhl2000 (Contributor Author) Jul 22, 2025

I have tested that --max-num-seqs set to any of [2, 4, 8, 16, 24, 32, 40] leads to a hang, while [1, 48, 56, ...] work normally. The hang occurs in a final dummy_run after all capturing, in gpu_worker.py around lines 285~292, which runs into a cudagraph replay (num_tokens = max_num_seqs) without building attn_metadata. I guess something weird is happening inside FlashInfer.

Contributor Author

@tdoublep Please take the new fix! It should be fine now. Could you also test whether it works for you?

Member

Yes, seems to work now. Thanks!

@fhl2000 (Contributor Author) commented Jul 22, 2025

@github-actions (bot)

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@tdoublep (Member)

This PR will be super-helpful for enabling full (decode-only) CUDA graphs for hybrid (mamba/attention) models in V1, where right now we need to use FlashInfer. I am testing these changes with my branch now.

@fhl2000 fhl2000 changed the title [V1][CUDA] Full cudagraph support for FlashInfer [Do not merge][V1][CUDA] Full cudagraph support for FlashInfer Jul 22, 2025
Signed-off-by: fhl2000 <[email protected]>
@fhl2000 fhl2000 changed the title [Do not merge][V1][CUDA] Full cudagraph support for FlashInfer [V1][CUDA] Full cudagraph support for FlashInfer Jul 22, 2025

@tdoublep (Member) left a comment

A few questions but otherwise looks good. I'm keen to see this PR merged because it should push us close to the point when we can deprecate V0 for hybrid models (which require FlashInfer currently).

Comment on lines 295 to 296
# Always activate creating attn_cudagraphs for dummy run to avoid
# illegal memory access for full cudagraph.
Member

Could you explain a bit more why this change is needed?

Contributor Author

Sure! The current dummy_run is not aware of whether it is warming up or should trigger cudagraph. It doesn't set skip_cuda_graph in the forward context, so it blindly expects model execution to go through cudagraph. However, even after the full cudagraphs are captured (the buffer addresses are fixed), if attn_metadata is not built correctly in dummy_run (nothing has gone through FlashInfer's plan function), replay may read incorrect values from those buffers and can potentially fall into an infinite loop. I think always activating this part is also not bad for piecewise cudagraph.
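A rough sketch of this behaviour, with hypothetical names and a simplified signature (the real _dummy_run in gpu_model_runner.py takes more arguments):

def dummy_run(runner, num_tokens, capture_attn_cudagraph):
    attn_metadata = None
    if capture_attn_cudagraph:
        # Go through FlashInfer's plan path so the persistent buffers hold
        # consistent values (indptr, last_page_len, ...) before any replay.
        attn_metadata = runner.build_decode_only_metadata(num_reqs=num_tokens)
    # If num_tokens matches a captured size, this replays the full cudagraph;
    # with stale persistent-buffer contents, FlashInfer's kernels can hang.
    return runner.run_model(num_tokens, attn_metadata)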

Contributor Author

Hey, I just found that always activating it here causes CI failures for FlexAttention, so I have to enable it only when full_cuda_graph is set and enforce-eager is not.

fhl2000 and others added 2 commits July 23, 2025 03:09
Signed-off-by: fhl2000 <[email protected]>

@SageMoore (Contributor) left a comment

I generally think the structure looks reasonable. I would like to have a better understanding of why this patch hangs when warming up shapes that used to be in the capture list but have been removed (I may not be describing this correctly; I'm just going off of your conversation with @tdoublep).

self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
Contributor

I'm not sure I understand why you need this. AIUI, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the GPUModelRunner?

I see there's some discussion about a hang when you don't pass an attention metadata into the dummy_run?

@fhl2000 (Contributor Author) Jul 24, 2025

Hey @SageMoore, thank you for the questions!

Is this required because you modified the list in the GPUModelRunner?

I think they are not related.

I'd like to explain a bit more here. This line of code is located after all cudagraph shapes for the modified list have been captured in gpu_model_runner. This dummy_run with num_tokens = max_num_reqs is therefore <= the max captured size of that modified list. Recall that dummy_run for attention_cg_support = PURE_DECODE_ONLY only runs pure-decode batches, so it only replays a decode-only cudagraph when it hits a size in the list; otherwise no cudagraph is used. However, when it does hit a replay, FlashInfer may be trapped in an infinite loop if the content of the persistent buffers is incorrect.

@SageMoore (Contributor) Jul 29, 2025

OK, please let me know if I'm understanding correctly. You are saying that, if max_num_reqs is a shape that has already been captured as a full cudagraph, we need to make sure that _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data, which can cause the graph replay to hang when running with the FlashInfer backend?

Collaborator

Ah OK, that makes sense to me, I think; basically, for all dummy runs after capture we need to build the metadata, since they will result in a graph replay.

Contributor Author

You are saying that, if max_num_reqs is a shape that has already been captured as a full cudagraph, we need to make sure that _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data, which can cause the graph replay to hang when running with the FlashInfer backend?

Exactly.

"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only
Contributor

Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.

Contributor Author

I think this is common practice now (see also the corresponding part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len = num_tokens.
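A simplified reconstruction of the build_for_cudagraph_capture path under discussion (the exact assertion condition here is an assumption; see vllm/v1/attention/backends/flashinfer.py for the real code):

def build_for_cudagraph_capture(self, m):
    # Dummy-run metadata arrives with max_query_len == num_tokens, so for a
    # pure-decode capture it is clamped to 1 rather than asserted on.
    assert m.num_reqs == m.num_actual_tokens, (  # assumed condition
        "FlashInfer only supports decode-only full CUDAGraph capture. "
        "Make sure all cudagraph capture sizes <= max_num_seq.")
    m.max_query_len = 1  # decode-only: one query token per request
    return self.build(0, m)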

@mergify (bot)

mergify bot commented Jul 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 24, 2025
@mergify mergify bot removed the needs-rebase label Jul 24, 2025
@fhl2000 (Contributor Author) commented Jul 24, 2025

I am making progress now after merging #21137. I plan to include the fast version of the decode wrapper's plan function in this PR!

))
return output_padded

# TODO:
Contributor

@fhl2000 Is this going to be done in this PR?

Contributor Author

Yes, I am doing this now. After #21137 merged there were many conflicts, and I realized it would be less painful to resolve those conflicts while also handling the fast plan override at the same time.

Contributor Author

Hey, this is now complete! Feel free to share any feedback!

@shyeh25

shyeh25 commented Aug 1, 2025

@fhl2000
Great work! It works well in llama3-70B FP8. But there is a functionality issue for llama3-70B FP4.
Could you take a look? Thanks

Repro server command:
python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8087 --model nvidia/Llama-3.3-70B-Instruct-FP4 --tokenizer nvidia/Llama-3.3-70B-Instruct-FP4 --dtype auto --kv-cache-dtype fp8 --tensor-parallel-size 1 --pipeline-parallel-size 1 --swap-space 16 --max-num-seqs 512 --trust-remote-code --max-model-len 2048 --gpu-memory-utilization 0.9 --max-num-batched-tokens 8192 --quantization modelopt_fp4 --enable-chunked-prefill --no-enable-prefix-caching --async-scheduling --disable-log-requests --compilation-config '{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3,"full_cuda_graph":true}'

Model: https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4

Error message:
Capturing CUDA graph shapes: 0%| | 0/67 [00:12<?, ?it/s]
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] WorkerProc hit an exception.
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] Traceback (most recent call last):
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 589, in worker_busy_loop
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] output = func(*args, **kwargs)
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 314, in compile_or_warm_up_model
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] self.model_runner.capture_model()
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 2563, in capture_model
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] self._dummy_run(num_tokens,
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] return func(*args, **kwargs)
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 2217, in _dummy_run
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] kv_cache_group_id].build_for_cudagraph_capture(
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flashinfer.py", line 600, in build_for_cudagraph_capture
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] return self.build(0, m)
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flashinfer.py", line 582, in build
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] self._plan(num_prefills, num_decodes, attn_metadata)
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flashinfer.py", line 458, in _plan
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] fast_plan_decode(
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flashinfer.py", line 853, in fast_plan_decode
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] self.plan(
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/flashinfer/decode.py", line 1012, in plan
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] self._plan_info = self._cached_module.plan(
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in call
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] return self._op(*args, **kwargs)
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=453450) ERROR 07-31 23:33:56 [multiproc_executor.py:594] RuntimeError: Error in function 'aligned_alloc' at /usr/local/lib/python3.12/dist-packages/flashinfer/data/include/flashinfer/allocator.h:48: Failed to allocate memory for batch_prefill_tmp_v with size 402128896 and alignment 16 in AlignedAllocator

@fhl2000 (Contributor Author) commented Aug 1, 2025

@fhl2000
Great work! It works well in llama3-70B FP8. But there is a functionality issue for llama3-70B FP4.
Could you take a look? Thanks

Thanks for the report! I'll try taking a look, but no promises on finding a fix, as I am not that familiar with quantization.

@nvpohanh (Contributor) commented Aug 1, 2025

@fhl2000 If you think it's not directly caused by your change, please let us know so that we can debug it. Thanks!

@fhl2000 (Contributor Author) commented Aug 1, 2025

@fhl2000 If you think it's not directly caused by your change, please let us know so that we can debug it. Thanks!

Yes, I don't think I have changed anything related to dtype/quant. Happy to hear that you can debug it. I am not sure whether this works in SGLang; if it does, please let me know.

@mergify (bot)

mergify bot commented Aug 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 1, 2025
@mergify mergify bot removed the needs-rebase label Aug 1, 2025
@mgoin mgoin merged commit 2332243 into vllm-project:main Aug 2, 2025
41 checks passed
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
@fhl2000 fhl2000 deleted the full_cg_flashinfer branch September 30, 2025 15:38

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1
