[V1][CUDA] Full cudagraph support for FlashInfer #21367
Conversation
Signed-off-by: fhl2000 <[email protected]>
Signed-off-by: fhl <[email protected]>
Code Review
This pull request introduces full CUDA graph support for the FlashInfer attention backend, which is a great performance enhancement for pure decode scenarios. The use of a new AttentionCGSupport enum to manage different levels of CUDA graph support across backends is a solid design choice that improves code clarity and maintainability.
The PR also includes an important bug fix to prevent graph capture for unsupported batch sizes, which is crucial for stability. I've identified one critical issue where a data structure for padding batch sizes is not updated after filtering the capture sizes, which could lead to runtime errors. I've provided a suggestion to fix this. Overall, this is a valuable contribution.
vllm/v1/worker/gpu_model_runner.py
Outdated
self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs]
This change correctly filters self.cudagraph_batch_sizes to prevent capturing graphs for sizes larger than max_num_seqs for PURE_DECODE_ONLY backends. However, the pad_for_cudagraph method, which is used at runtime to determine the padded graph size, relies on a mapping (bs_to_padded_graph_size) that was initialized with the original, unfiltered cudagraph_batch_sizes.
This discrepancy can lead to a KeyError at runtime. For example, if a batch with num_decodes is processed, pad_for_cudagraph might return a padded size that was filtered out and for which no CUDA graph was captured. This will cause a lookup failure in _decode_wrappers_cudagraph.
To fix this, you should re-initialize the padding map after filtering self.cudagraph_batch_sizes.
Suggested change:

self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs
]
self.vllm_config.compilation_config.init_with_cudagraph_sizes(
    self.cudagraph_batch_sizes)
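To make the failure mode behind this suggestion concrete, here is a small self-contained sketch (helper and variable names are illustrative stand-ins, not vLLM internals): a padding map built from the unfiltered capture list can return a size for which no graph was ever captured.

```python
def build_padding_map(capture_sizes, max_bs):
    """Map every batch size to the smallest captured size that covers it."""
    padded = {}
    for bs in range(1, max_bs + 1):
        covering = [s for s in sorted(capture_sizes) if s >= bs]
        padded[bs] = covering[0] if covering else None
    return padded

capture_sizes = [1, 2, 4, 8]                   # original capture list
pad_map = build_padding_map(capture_sizes, 8)  # built before filtering

max_num_seqs = 2
capture_sizes = [s for s in capture_sizes if s <= max_num_seqs]   # now [1, 2]
captured_graphs = {s: object() for s in capture_sizes}            # stand-in graphs

num_decodes = 3
padded = pad_map[num_decodes]        # stale map still answers 4
print(padded in captured_graphs)     # False: no graph was captured for size 4
```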
Good catch on pad_for_cudagraph, though I think it would not affect the final correctness.
I find that the following hangs:
VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve models/Llama-3.1-8B-Instruct --no-enable-prefix-caching --compilation-config='{"full_cuda_graph": true}' --max-num-seqs 2
at this point:
Capturing CUDA graph shapes: 100%|██████████| 2/2 [00:00<00:00, 2.42it/s]
INFO 07-22 10:32:54 [gpu_model_runner.py:2404] Graph capturing finished in 1 secs, took 0.42 GiB
It just hangs here - could this be related?
If I remove --max-num-seqs then it works fine, so I think it is indeed related.
I found that overriding the GPU model runner's cudagraph_batch_sizes is enough. Also, the vllm_config.compilation_config.init_with_cudagraph_sizes method does not actually override the compilation config's cudagraph_batch_sizes after its first call.
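A minimal sketch of the behavior described here, assuming the method simply ignores calls once sizes have been set (an illustration of the stated semantics, not the real vLLM implementation):

```python
class _CompilationConfigSketch:
    """Illustration of the described 'first call wins' behavior only."""

    def __init__(self):
        self.cudagraph_capture_sizes = []

    def init_with_cudagraph_sizes(self, sizes):
        if self.cudagraph_capture_sizes:   # already initialized earlier
            return                         # later calls leave the config as-is
        self.cudagraph_capture_sizes = sorted(sizes, reverse=True)

cfg = _CompilationConfigSketch()
cfg.init_with_cudagraph_sizes([1, 2, 4, 8])
cfg.init_with_cudagraph_sizes([1, 2])      # ignored after the first call
print(cfg.cudagraph_capture_sizes)         # [8, 4, 2, 1]
```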
Hmm unfortunately it doesn't happen on main + FlashAttention. The hang is 100% reproducible using the code from this PR.
I will try my best to figure it out.
I have tested that setting --max-num-seqs to any of [2, 4, 8, 16, 24, 32, 40] leads to a hang, while [1, 48, 56, ...] work normally. The hang occurs in the final dummy_run after all capturing, in gpu_worker.py around lines 285~292, which runs into a cudagraph replay (num_tokens = max_num_seqs) without creating attn_metadata. I guess something weird happens inside FlashInfer.
@tdoublep Take the new fix! It should be fine now. Could you please also test if it works for you?
Yes, seems to work now. Thanks!
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
This PR will be super-helpful for enabling full (decode-only) CUDA graphs for hybrid (mamba/attention) models in V1, where right now we need to use FlashInfer. I am testing these changes with my branch now.
Signed-off-by: fhl2000 <[email protected]>
A few questions but otherwise looks good. I'm keen to see this PR merged because it should push us close to the point when we can deprecate V0 for hybrid models (which require FlashInfer currently).
vllm/v1/worker/gpu_worker.py
Outdated
# Always activate creating attn_cudagraphs for dummy run to avoid
# illegal memory access for full cudagraph.
Could you explain a bit more why this change is needed?
Sure! The current dummy_run is not aware of whether it is warming up or expected to trigger a cudagraph. It doesn't set skip_cuda_graph in the forward context, so it blindly expects model execution to go through the cudagraph. However, even after the full cudagraphs are captured (so the buffer addresses are fixed), if attn_metadata is not built correctly in dummy_run (nothing has gone through FlashInfer's plan function), the replay may read incorrect values from those buffers and can potentially fall into an infinite loop. I think always activating this part is also not bad for piecewise cudagraph.
Hey, I just found that always activating it here causes CI failures for FlexAttention, so I have to enable this only when full_cuda_graph is set and eager mode is not enforced.
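For reference, a hedged, self-contained sketch of the gating described above (the config fields are modeled with stand-in dataclasses; the real code reads them from vLLM's compilation and model configs):

```python
from dataclasses import dataclass

@dataclass
class _CompilationCfg:        # stand-in for the relevant compilation config field
    full_cuda_graph: bool

@dataclass
class _ModelCfg:              # stand-in for the relevant model config field
    enforce_eager: bool

def should_build_attn_metadata(comp: _CompilationCfg, model: _ModelCfg) -> bool:
    # Build attention metadata for the post-capture dummy run only when full
    # cudagraphs are enabled and eager execution is not enforced; always
    # enabling it broke FlexAttention in CI, per the note above.
    return comp.full_cuda_graph and not model.enforce_eager

print(should_build_attn_metadata(_CompilationCfg(True), _ModelCfg(False)))   # True
print(should_build_attn_metadata(_CompilationCfg(False), _ModelCfg(False)))  # False
```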
Signed-off-by: fhl2000 <[email protected]>
Signed-off-by: fhl <[email protected]>
I generally think the structure looks reasonable. I would like to have a better understanding for why this patch hangs when warming up shapes that used to be in the capture list but have been removed (I may not be describing this correctly, I'm just going off of your conversation with @tdoublep).
    self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
I'm not sure I understand why you need this. AIUI, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the GPUModelRunner?
I see there's some discussion about a hang when you don't pass an attention metadata into the dummy_run?
Hey @SageMoore, thank you for the questions!
Is this required because you modified the list in the GPUModelRunner?
I think they are not related.
I'd like to explain a bit more. This line of code actually runs after all cudagraph shapes from the modified list have been captured in gpu_model_runner, so this dummy_run with num_tokens = max_num_reqs is <= the max captured size of that modified list. Recall that dummy_run for attention_cg_support = PURE_DECODE_ONLY only runs pure decode batches, so it only falls into a decode-only cudagraph replay when it hits a size in the list; otherwise no cudagraph is used. However, when it does hit a replay, FlashInfer may be trapped in an infinite loop if the content of the persistent buffers is incorrect.
OK please let me know if I'm understanding correctly. You are saying that, if max_num_reqs is a shape that has already been full cudagraph captured, we need to make sure that the _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data which can cause the graph replay to hang when running with the flash infer backend?
Ah OK, that makes sense to me; basically, for all dummy runs after capture we need to build the metadata, since they will result in a graph replay.
You are saying that, if max_num_reqs is a shape that has already been full cudagraph captured, we need to make sure that the _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data which can cause the graph replay to hang when running with the flash infer backend
Exactly.
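The point about stale persistent buffers is easy to demonstrate outside FlashInfer; here is a tiny PyTorch illustration (requires a CUDA device, and is only an analogy for the metadata buffers discussed in this thread):

```python
import torch

static_in = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_in * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

# A replay always consumes whatever currently sits in the captured buffers.
static_in.fill_(3.0)   # analogous to a plan() call refreshing the metadata
g.replay()
torch.cuda.synchronize()
print(static_out)      # tensor of 6.0s

# If the buffers instead held stale or garbage values (e.g. bogus kv page
# indices), the replayed kernels could read nonsense and, as reported above,
# spin forever.
```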
| "FlashInfer only supports decode-only full CUDAGraph capture. " \ | ||
| "Make sure all cudagraph capture sizes <= max_num_seq." | ||
|
|
||
| m.max_query_len = 1 # decode-only |
Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.
I think this is common practice now (also see the corresponding part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len = num_tokens.
This pull request has merge conflicts that must be resolved before it can be merged.
I am making progress now after merging #21137. I plan to implement the fast version of the decode wrapper's plan function in this PR!
))
return output_padded
...
# TODO:
@fhl2000 Is this going to be done in this PR?
Yes, I am doing this now. After #21137 was merged there were many conflicts, and I realized it would be less painful to resolve them while also handling the fast plan override at the same time.
Hey, this is now complete! Feel free to share any feedback!
Signed-off-by: fhl2000 <[email protected]>
@fhl2000 Repro server command:
Model: https://huggingface.co/nvidia/Llama-3.3-70B-Instruct-FP4
Error message:
Thanks for the report! I'll try taking a look, but no promises on finding a fix, as I am not that familiar with quantization.
@fhl2000 If you think it's not directly caused by your change, please let us know so that we can debug it. Thanks!
Yes, I don't think I have changed anything related to dtype/quant. Happy to hear that you can debug it. I am not sure whether this works in SGLang; if it does, please let me know.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Jinzhen Lin <[email protected]>
Signed-off-by: Noam Gat <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
Essential Elements of an Effective PR Description Checklist
…supported_models.md and examples for a new model.
Purpose
This PR is split from the original #20059 to support full cudagraph for FlashInfer (pure decode only): it runs pure decode batches with a full cudagraph and falls back to no cudagraph for mixed prefill-decode batches. Hoping to land this first, before #20059.
Details include:
This PR also fixes a potential bug originally from #18581, where an assertion error is raised during capture if max_capture_size is greater than max_num_reqs. To resolve this, a new enum type AttentionCGSupport (adapted from #20059) is introduced to distinguish how each backend supports cudagraph, so that we can overwrite cudagraph_batch_sizes to contain no sizes greater than max_num_seqs.
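As a rough illustration of the enum's role (member names other than PURE_DECODE_ONLY are assumptions here, not necessarily what the PR defines):

```python
from enum import Enum, auto

class AttentionCGSupport(Enum):
    NEVER = auto()             # backend cannot run under a full cudagraph
    PURE_DECODE_ONLY = auto()  # full cudagraph only for pure-decode batches
    ALWAYS = auto()            # full cudagraph for any batch composition

# Filtering sketch: for PURE_DECODE_ONLY backends, capture sizes above
# max_num_seqs can never be hit by a pure-decode batch, so they are dropped.
support = AttentionCGSupport.PURE_DECODE_ONLY
cudagraph_batch_sizes, max_num_seqs = [1, 2, 4, 8, 16], 4
if support == AttentionCGSupport.PURE_DECODE_ONLY:
    cudagraph_batch_sizes = [s for s in cudagraph_batch_sizes
                             if s <= max_num_seqs]
print(cudagraph_batch_sizes)   # [1, 2, 4]
```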
NOTE: Currently, manually setting the max capture size seems impossible after the introduction of Pydantic, which restricts the --cuda_graph_sizes config to int type.
Limitation:
(updated) After #21137 is merged
To resolve the conflicts with the early version of this PR and further reduce potential overhead, the new commits add both host-side and device-side persistent buffers and override FlashInfer's decode plan function for cudagraph execution. Now, for a full-cudagraph common decode, we eliminate copies from temporary tensors to persistent buffers by writing results directly into the persistent buffers; we further avoid a device-to-device copy of paged_kv_indices, retaining only two host-to-device copies of paged_kv_indptr and paged_kv_last_page_len by overriding that plan function.
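A hedged sketch of the copy pattern described above (buffer names follow the description; the sizes and surrounding plumbing are assumptions, and this is not the actual plan-override code):

```python
import torch

MAX_PAGES, MAX_BATCH = 4096, 256

# Persistent device-side buffers whose addresses are baked into the cudagraph.
paged_kv_indices_dev = torch.zeros(MAX_PAGES, dtype=torch.int32, device="cuda")
paged_kv_indptr_dev = torch.zeros(MAX_BATCH + 1, dtype=torch.int32, device="cuda")
paged_kv_last_page_len_dev = torch.zeros(MAX_BATCH, dtype=torch.int32, device="cuda")

# Persistent host-side pinned buffers: metadata is written straight into these
# instead of into temporary tensors that would need an extra copy.
paged_kv_indptr_host = torch.zeros(MAX_BATCH + 1, dtype=torch.int32,
                                   pin_memory=True)
paged_kv_last_page_len_host = torch.zeros(MAX_BATCH, dtype=torch.int32,
                                          pin_memory=True)

def refresh_decode_metadata(num_reqs: int) -> None:
    # Only two small host-to-device copies per decode step; paged_kv_indices
    # is assumed to be filled in place on the device, so no device-to-device
    # copy is needed for it.
    paged_kv_indptr_dev[: num_reqs + 1].copy_(
        paged_kv_indptr_host[: num_reqs + 1], non_blocking=True)
    paged_kv_last_page_len_dev[:num_reqs].copy_(
        paged_kv_last_page_len_host[:num_reqs], non_blocking=True)
```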
Test Plan
lm_eval, benchmark_serving
Test Result
See comments below.
(Optional) Documentation Update