
Conversation

@benchislett
Collaborator

@benchislett benchislett commented Oct 15, 2025

Purpose

TRTLLM-gen kernels support full CUDA graphs, but they are only used with FlashInfer on Blackwell under certain conditions.
It might not be safe to change FlashInfer's cudagraph_support to UNIFORM_BATCH unconditionally, but we can still set it when we know the TRTLLM-gen backend will be used.

Also update the docs to reflect the FlashInfer and FlashInferMLA CUDA graph compatibility.

FIX #26856
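
In rough pseudocode, the idea looks like the sketch below. The enum subset mirrors AttentionCGSupport in vllm/v1/attention/backends/utils.py; the builder body, the default value shown, and the can_use_trtllm_attention flag are simplified stand-ins, not the actual FlashInfer backend code.

```python
from enum import Enum


class AttentionCGSupport(Enum):
    # Ordered subset of the real enum; a higher value means broader CUDA graph support.
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


class FlashInferMetadataBuilder:
    # Conservative class-level default (illustrative value).
    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, can_use_trtllm_attention: bool) -> None:
        # Opt in to uniform-batch full CUDA graphs only when we already know
        # the TRTLLM-gen kernels will be used; the instance attribute shadows
        # the class-level default for this builder only.
        if can_use_trtllm_attention:
            self.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
```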

Test Plan

Ran Llama 3.1 8B-Instruct with EAGLE3 and confirmed that the lm_eval gsm8k scores are unchanged compared to main and to a run with TRTLLM attention force-disabled. Confirmed via a torch profile that full graphs are now issued for verification when TRTLLM attention is enabled.

Test Result

TRTLLM on:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7726|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7013|±  |0.0126|

TRTLLM off:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7726|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7013|±  |0.0126|

Benchmarks

MT-Bench at concurrency 1 shows a small speedup (~2%).

vllm serve meta-llama/Llama-3.1-8B-Instruct --speculative-config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3}' &

vllm bench serve --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --max-concurrency 1 --model meta-llama/Llama-3.1-8B-Instruct --base-url http://0.0.0.0:8049

Before:

============ Serving Benchmark Result ============
Successful requests:                     80        
Maximum request concurrency:             1         
Benchmark duration (s):                  42.58     
Total input tokens:                      8133      
Total generated tokens:                  16955     
Request throughput (req/s):              1.88      
Output token throughput (tok/s):         398.23    
Peak output token throughput (tok/s):    186.00    
Peak concurrent requests:                4.00      
Total Token throughput (tok/s):          589.25    
---------------Time to First Token----------------
Mean TTFT (ms):                          12.11     
Median TTFT (ms):                        11.88     
P99 TTFT (ms):                           14.29     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.45      
Median TPOT (ms):                        2.45      
P99 TPOT (ms):                           3.33      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.36      
Median ITL (ms):                         5.36      
P99 ITL (ms):                            5.61      
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     80        
Maximum request concurrency:             1         
Benchmark duration (s):                  41.73     
Total input tokens:                      8133      
Total generated tokens:                  16795     
Request throughput (req/s):              1.92      
Output token throughput (tok/s):         402.47    
Peak output token throughput (tok/s):    190.00    
Peak concurrent requests:                4.00      
Total Token throughput (tok/s):          597.37    
---------------Time to First Token----------------
Mean TTFT (ms):                          11.86     
Median TTFT (ms):                        11.75     
P99 TTFT (ms):                           14.93     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.43      
Median TPOT (ms):                        2.37      
P99 TPOT (ms):                           3.39      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.26      
Median ITL (ms):                         5.25      
P99 ITL (ms):                            5.48      
==================================================

@benchislett benchislett requested a review from mgoin as a code owner October 15, 2025 19:12
@mergify

mergify bot commented Oct 15, 2025

Documentation preview: https://vllm--26937.org.readthedocs.build/en/26937/

@mergify mergify bot added the documentation and v1 labels Oct 15, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request enables full CUDA graphs for speculative decoding with FlashInfer when TRT-LLM attention kernels are available, which is a valuable performance enhancement. The implementation correctly updates the cudagraph_support attribute in FlashInferMetadataBuilder at runtime based on whether TRT-LLM attention can be used. The change from a class variable to an instance variable for cudagraph_support is appropriate for this dynamic behavior. The documentation has also been updated to reflect these changes. The logic appears sound and the provided test results indicate that correctness is maintained while enabling this optimization.

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

@vadiklyutiy
Collaborator

Regarding the performance improvement:
I tried this on Qwen3-Next with 2 prediction tokens.
With batch=1 it improves from 92 tok/s to 222 tok/s.

@mgoin
Member

mgoin commented Oct 15, 2025

cc @LucasWilkinson @ProExpertProg regarding updating AttentionCGSupport dynamically

@LucasWilkinson
Collaborator

LucasWilkinson commented Oct 16, 2025

> cc @LucasWilkinson @ProExpertProg regarding updating AttentionCGSupport dynamically

Dynamically updating it should be fine, since the only place we consult it is on builder instances, here:

if builder.cudagraph_support.value < min_cg_support.value:

But if we are going to dynamically update it, I think we should make it an instance property instead of a class variable, just to avoid confusion and future bugs.
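
To make the distinction concrete (plain Python, not vLLM code): assigning through the class would change the value seen by every builder, while assigning on the instance only shadows the default for that one builder.

```python
class Builder:
    cudagraph_support = "UNIFORM_SINGLE_TOKEN_DECODE"  # class-level default


a, b = Builder(), Builder()

# Updating through the class changes what every instance sees.
Builder.cudagraph_support = "UNIFORM_BATCH"
assert a.cudagraph_support == b.cudagraph_support == "UNIFORM_BATCH"

# Restore the default, then shadow it on a single instance instead.
Builder.cudagraph_support = "UNIFORM_SINGLE_TOKEN_DECODE"
a.cudagraph_support = "UNIFORM_BATCH"
assert b.cudagraph_support == "UNIFORM_SINGLE_TOKEN_DECODE"
```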

@mergify mergify bot added the rocm Related to AMD ROCm label Oct 21, 2025
Collaborator

@LucasWilkinson LucasWilkinson left a comment

LGTM; there are some nits that should be addressed (specifically, for the CPU backend I think we should still keep reorder_batch_threshold = 1).

It is a bit harder to see where cudagraph_support is set now :/ I guess the alternative would be to use a function, i.e. add a get_cudagraph_support() function in the base class (I think the current implementation is better, but I'm also flip-flopping haha).


 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int

Collaborator

nit: I still think this needs to be set?

Collaborator Author

It is set in the constructor: _init_reorder_batch_threshold(1, False)

The type annotation is kept to indicate that it will never be None on this class and its subclasses. This is a common pattern across the changes in this PR.
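
A rough sketch of the pattern being described, with the base class and helper heavily simplified (the real _init_reorder_batch_threshold signature in vLLM may differ):

```python
from typing import Optional


class AttentionMetadataBuilder:
    # In the base class, "no batch reordering" is a valid state.
    reorder_batch_threshold: Optional[int] = None

    def _init_reorder_batch_threshold(
        self, reorder_batch_threshold: int, supports_spec_as_decode: bool
    ) -> None:
        # Simplified: the real helper also accounts for speculative decoding.
        self.reorder_batch_threshold = reorder_batch_threshold


class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder):
    # Annotation only, no value: documents that on this class (and its
    # subclasses) the attribute is always an int, never None.
    reorder_batch_threshold: int

    def __init__(self) -> None:
        self._init_reorder_batch_threshold(1, False)
```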

     AttentionMetadataBuilder[XFormersAttentionMetadata]
 ):
-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int

Collaborator

nit: is this still needed?

Collaborator Author

Yes, as a type annotation; see the previous comment.

 )

-    reorder_batch_threshold: int = 1
+    reorder_batch_threshold: int

Collaborator

nit: is this still needed?

Collaborator Author

Yes, as a type annotation; see the previous comment.

@fhl2000
Contributor

fhl2000 commented Nov 2, 2025

Hi @benchislett, it is no longer safe to dynamically update cudagraph_support inside __init__() since #27427 was merged, because we now resolve the cudagraph mode (which requires cudagraph_support) before actually initializing the builder instance. So making it an instance property is not a good idea. Instead, I would add a class method get_cudagraph_support() for this.

@benchislett
Collaborator Author

@fhl2000 that breaks this PR pretty firmly. The main idea for enabling full CUDA graphs for FlashInfer is to opt in dynamically based on whether the TRTLLM kernels can be used, which depends on a number of parameters, some of which are specific to the actual model architecture. Do you see an easy way around this?

@fhl2000
Contributor

fhl2000 commented Nov 11, 2025

The right logic should be: determine cudagraph_support for each backend (builder class) -> resolve cudagraph mode -> initialize the cudagraph-related state of each backend.

Since the cudagraph_support of a specific backend is fixed after its initialization, can we extract the logic that determines cudagraph_support into a static method? I think passing it the same arguments as builder_class.__init__() is enough.

An alternative would be to delay the "initialize cudagraph-related state of each backend" step until after backend initialization but before the first build() call (perhaps triggered in build_for_cudagraph_capturing), so the flow becomes: initialize each backend (also determining cudagraph support there) -> resolve cudagraph mode -> trigger cudagraph initialization in build_for_cudagraph_capturing.

I think the first option is easier, but let's see if you and @LucasWilkinson have other concerns about it.
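
A rough sketch of the first option, reusing the stand-in AttentionCGSupport enum from the sketch in the PR description; the method name comes from the discussion above, while the argument list and the trtllm_gen_is_usable helper are assumptions, not the actual #28479 API.

```python
def trtllm_gen_is_usable(vllm_config, kv_cache_spec) -> bool:
    # Hypothetical placeholder for the FlashInfer backend's TRTLLM-gen
    # eligibility check (GPU architecture, dtype, head size, ...).
    return False


class FlashInferMetadataBuilder:
    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec) -> "AttentionCGSupport":
        # Decided from configuration alone, before any builder instance
        # exists, so the cudagraph mode can be resolved up front.
        if trtllm_gen_is_usable(vllm_config, kv_cache_spec):
            return AttentionCGSupport.UNIFORM_BATCH
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```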

@mergify

mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2025
@benchislett
Collaborator Author

Closing the PR for now while I work on a refactor to fix up cudagraph_support.

@benchislett
Collaborator Author

@fhl2000 @LucasWilkinson I took another stab at this in #28479, following @fhl2000's suggestion. I think this will work well.

I omitted the refactoring of reorder_batch_threshold to simplify the diff; it can be added in a follow-up PR if still desired.


Labels

documentation (Improvements or additions to documentation), needs-rebase, nvidia, rocm (Related to AMD ROCm), v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Performance]: FlashInfer attn backend. Use dynamic AttentionCGSupport

5 participants