Skip to content

Commit b5b00ce

Browse files
committed
[https://nvbugs/5534705][fix] Skip draft_model forward when draft_len is 0
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent e9e4632 commit b5b00ce

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -695,19 +695,24 @@ def release_batch(result: ScheduledRequests | None):
695695
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
696696
reverse=True)
697697
# Create CUDA graphs for different draft lengths
698-
draft_lengths = [self.max_draft_len]
699-
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
700-
# so that when we disable spec decode at runtime, we can still run the captured graph.
701-
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
702-
if (not self.is_draft_model and self.max_draft_len > 0
703-
and not self.spec_config.spec_dec_mode.use_one_engine()
704-
# Assume that speculation is always on if the user didn't give us a max_concurrency
705-
# value. This will save on memory.
706-
and self.spec_config.max_concurrency is not None):
707-
draft_lengths.append(0)
708-
if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
709-
spec_resource_manager, Eagle3ResourceManager):
710-
draft_lengths.append(self.original_max_draft_len)
698+
draft_lengths = []
699+
if self.is_draft_model:
700+
if self.model_is_wrapped:
701+
# The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
702+
draft_lengths.append(self.original_max_draft_len)
703+
else:
704+
draft_lengths.append(self.max_draft_len)
705+
else:
706+
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
707+
# so that when we disable spec decode at runtime, we can still run the captured graph.
708+
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
709+
if (self.max_draft_len > 0
710+
and not self.spec_config.spec_dec_mode.use_one_engine()
711+
# Assume that speculation is always on if the user didn't give us a max_concurrency
712+
# value. This will save on memory.
713+
and self.spec_config.max_concurrency is not None):
714+
draft_lengths.append(0)
715+
draft_lengths = [self.max_draft_len]
711716

712717
for bs in cuda_graph_batch_sizes:
713718
if bs > self.batch_size:
@@ -723,6 +728,7 @@ def release_batch(result: ScheduledRequests | None):
723728
logger.info(
724729
f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
725730
)
731+
726732
self.enable_spec_decode = draft_len > 0 or self.is_draft_model
727733

728734
def _update_draft_inference_state(is_first_draft: bool,

0 commit comments

Comments
 (0)