
Commit efd4ffa

Authored by ziyixiong-nv
[https://nvbugs/5534705][fix] Skip unnecessary CUDA graph capture (#8050)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent: 84d2f12 · commit: efd4ffa

File tree

1 file changed: +20 -13 lines


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 20 additions & 13 deletions
@@ -695,19 +695,25 @@ def release_batch(result: ScheduledRequests | None):
             cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                             reverse=True)
             # Create CUDA graphs for different draft lengths
-            draft_lengths = [self.max_draft_len]
-            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-            # so that when we disable spec decode at runtime, we can still run the captured graph.
-            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (not self.is_draft_model and self.max_draft_len > 0
-                    and not self.spec_config.spec_dec_mode.use_one_engine()
-                    # Assume that speculation is always on if the user didn't give us a max_concurrency
-                    # value. This will save on memory.
-                    and self.spec_config.max_concurrency is not None):
-                draft_lengths.append(0)
-            if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                    spec_resource_manager, Eagle3ResourceManager):
-                draft_lengths.append(self.original_max_draft_len)
+            draft_lengths = []
+            if self.is_draft_model:
+                if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                        spec_resource_manager, Eagle3ResourceManager):
+                    # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                    draft_lengths.append(self.original_max_draft_len)
+                else:
+                    draft_lengths.append(self.max_draft_len)
+            else:
+                # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+                # so that when we disable spec decode at runtime, we can still run the captured graph.
+                # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+                if (self.max_draft_len > 0
+                        and not self.spec_config.spec_dec_mode.use_one_engine()
+                        # Assume that speculation is always on if the user didn't give us a max_concurrency
+                        # value. This will save on memory.
+                        and self.spec_config.max_concurrency is not None):
+                    draft_lengths.append(0)
+                draft_lengths.append(self.max_draft_len)
 
             for bs in cuda_graph_batch_sizes:
                 if bs > self.batch_size:
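To trace the new capture policy outside the engine, here is a minimal, self-contained Python sketch of the selection logic in the first hunk. EngineConfig and select_draft_lengths are hypothetical stand-ins invented for illustration (they are not TensorRT-LLM APIs), and the Eagle3 resource-manager, one-engine, and model-wrapper checks are collapsed into plain booleans.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EngineConfig:
    # Hypothetical stand-ins for the ModelEngine attributes used in the diff.
    is_draft_model: bool
    model_is_wrapped: bool
    is_spec_decode: bool
    max_draft_len: int
    original_max_draft_len: int
    uses_one_engine: bool
    max_concurrency: Optional[int]
    has_eagle3_resource_manager: bool


def select_draft_lengths(cfg: EngineConfig) -> List[int]:
    """Each returned entry triggers one CUDA graph capture per batch size."""
    draft_lengths: List[int] = []
    if cfg.is_draft_model:
        if (cfg.model_is_wrapped and cfg.is_spec_decode
                and cfg.has_eagle3_resource_manager):
            # CDL path: draft_len > 0 encodes the drafting-loop iteration count.
            draft_lengths.append(cfg.original_max_draft_len)
        else:
            draft_lengths.append(cfg.max_draft_len)
    else:
        # Capture an extra draft_len == 0 graph only when spec decode can
        # actually be disabled at runtime: two-engine mode with an explicit
        # max_concurrency limit.
        if (cfg.max_draft_len > 0 and not cfg.uses_one_engine
                and cfg.max_concurrency is not None):
            draft_lengths.append(0)
        draft_lengths.append(cfg.max_draft_len)
    return draft_lengths


# A target model that can toggle speculation off captures two graph variants.
target = EngineConfig(is_draft_model=False, model_is_wrapped=False,
                      is_spec_decode=True, max_draft_len=3,
                      original_max_draft_len=3, uses_one_engine=False,
                      max_concurrency=8, has_eagle3_resource_manager=False)
assert select_draft_lengths(target) == [0, 3]

The practical effect shows in the draft-model branch: it now captures exactly one graph variant per batch size (either the CDL iteration count or max_draft_len), where the old code could capture both.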
@@ -723,6 +729,7 @@ def release_batch(result: ScheduledRequests | None):
                         logger.info(
                             f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                         )
+
                         self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
         def _update_draft_inference_state(is_first_draft: bool,
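The self.enable_spec_decode assignment in the second hunk's context is what makes the draft_len == 0 capture usable: each warmup iteration pairs the graph being captured with the matching speculation state. A one-function sketch of that predicate, with illustrative parameter names rather than the engine's attributes:

def spec_decode_enabled_for_warmup(draft_len: int, is_draft_model: bool) -> bool:
    # Draft models always warm up with speculation on; the target model
    # enables it only for a non-zero draft length, so the draft_len == 0
    # graph is captured in the same configuration it will execute in when
    # speculation is disabled at runtime.
    return draft_len > 0 or is_draft_model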
