# [V1][CUDA] Full cudagraph support for FlashInfer #21367
```diff
@@ -321,11 +321,16 @@ def compile_or_warm_up_model(self) -> None:
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access for full cudagraph replay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and \
+                not self.model_config.enforce_eager
```
**SageMoore:** I'm not sure I understand why you need this. AIUI, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the `gpu_model_runner`? I see there's some discussion about a hang when you don't pass attention metadata into the `dummy_run`?

**Reply:** Hey @SageMoore, thank you for the questions!
I think they are not related; let me explain further. This line of code actually runs after all cudagraph shapes for the modified list in `gpu_model_runner` have been captured, and this `dummy_run` with `num_tokens=max_num_reqs` is <= the max captured size of that list. Recall that a `dummy_run` for `attention_cg_support=PURE_DECODE_ONLY` only runs pure-decode batches, so it hits a cudagraph replay of the decode-only graph only when `num_tokens` matches a size in the list; otherwise no cudagraph is used. However, when it does hit a replay, FlashInfer may be trapped in an infinite loop if the content of the persistent buffers is incorrect.

**SageMoore:** OK, please let me know if I'm understanding correctly. You are saying that, if …

**Another reviewer:** Ah, OK, makes sense to me, I think; basically, for all dummy runs after capture, we need to build the metadata, since they will result in a graph replay.
**Reply:** Exactly.
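To make the persistent-buffer point concrete, here is a minimal, self-contained PyTorch sketch (illustrative only, not vLLM or FlashInfer code): a captured CUDA graph replays against the exact tensors it was captured with, so whoever triggers a replay must first write valid data into those tensors.

```python
import torch

# Persistent buffers: the graph is captured against these exact tensors,
# so every replay reads whatever they currently contain.
static_input = torch.zeros(8, device="cuda")
static_output = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, per the PyTorch CUDA graph recipe.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(static_input * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output.copy_(static_input * 2)

# Correct replay: refresh the persistent input buffer first.
static_input.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
g.replay()
torch.cuda.synchronize()
print(static_output)  # [0, 2, 4, ..., 14], as expected

# Incorrect replay: nobody refreshed the buffer, so the graph blindly
# consumes stale contents.
static_input.fill_(float("nan"))
g.replay()
torch.cuda.synchronize()
print(static_output)  # NaNs: the graph reused whatever was in the buffer
```

For a simple elementwise op, stale contents only produce a wrong answer; for a kernel whose control flow depends on buffer contents, such as FlashInfer walking paged-KV indices, garbage can mean illegal memory accesses or a loop that never terminates. That is why the post-capture dummy run must build real attention metadata before it can safely replay.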
```diff

             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model:
```
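Summarizing the thread as pseudocode: the sketch below is hypothetical (the names `captured_graphs`, `build_attn_metadata`, and `eager_forward` are illustrative stand-ins, not vLLM's actual internals), but it captures the dispatch rule the fix relies on.

```python
from typing import Callable, Dict

import torch


def run_dummy_batch(
    num_tokens: int,
    captured_graphs: Dict[int, torch.cuda.CUDAGraph],
    build_attn_metadata: Callable[[int], None],
    eager_forward: Callable[[int], None],
) -> None:
    if num_tokens in captured_graphs:
        # Replay path: the captured kernels read the persistent buffers
        # directly, so valid attention metadata must be written into them
        # first; otherwise FlashInfer can chase garbage page indices.
        build_attn_metadata(num_tokens)
        captured_graphs[num_tokens].replay()
    else:
        # Eager path: kernels take real arguments, so there is no
        # stale-buffer hazard.
        eager_forward(num_tokens)
```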