
Commit 11de1dc

ziyixiong-nv committed
[https://nvbugs/5534705][fix] Skip draft_model forward when draft_len is 0
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 9b3d7cc commit 11de1dc
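
The title describes the behavioral fix: when a drafting step resolves to draft_len == 0, the draft model's forward pass should be skipped rather than launched with an empty speculation budget. A minimal sketch of such a guard, assuming hypothetical names (draft_step, run_draft_forward) that are not part of tensorrt_llm:

    from typing import Callable, List

    # Hypothetical illustration of the guard named in the commit title;
    # run_draft_forward stands in for the draft model's forward pass and
    # is not a real tensorrt_llm API.
    def draft_step(draft_len: int,
                   run_draft_forward: Callable[[int], List[int]]) -> List[int]:
        if draft_len == 0:
            # Nothing to draft: skip the draft model forward entirely.
            return []
        return run_draft_forward(draft_len)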

1 file changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 19 additions & 13 deletions
@@ -697,19 +697,24 @@ def release_batch(result: ScheduledRequests | None):
             cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                             reverse=True)
             # Create CUDA graphs for different draft lengths
-            draft_lengths = [self.max_draft_len]
-            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-            # so that when we disable spec decode at runtime, we can still run the captured graph.
-            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (not self.is_draft_model and self.max_draft_len > 0
-                    and not self.spec_config.spec_dec_mode.use_one_engine()
-                    # Assume that speculation is always on if the user didn't give us a max_concurrency
-                    # value. This will save on memory.
-                    and self.spec_config.max_concurrency is not None):
-                draft_lengths.append(0)
-            if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                    spec_resource_manager, Eagle3ResourceManager):
-                draft_lengths.append(self.original_max_draft_len)
+            draft_lengths = []
+            if self.is_draft_model:
+                if self.model_is_wrapped:
+                    # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                    draft_lengths.append(self.original_max_draft_len)
+                else:
+                    draft_lengths.append(self.max_draft_len)
+            else:
+                # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+                # so that when we disable spec decode at runtime, we can still run the captured graph.
+                # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+                if (self.max_draft_len > 0
+                        and not self.spec_config.spec_dec_mode.use_one_engine()
+                        # Assume that speculation is always on if the user didn't give us a max_concurrency
+                        # value. This will save on memory.
+                        and self.spec_config.max_concurrency is not None):
+                    draft_lengths.append(0)
+                draft_lengths.append(self.max_draft_len)
 
             for bs in cuda_graph_batch_sizes:
                 if bs > self.batch_size:
@@ -725,6 +730,7 @@ def release_batch(result: ScheduledRequests | None):
                     logger.info(
                         f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                     )
+
                     self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
         def _update_draft_inference_state(is_first_draft: bool,
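
Read in isolation, the new branching in the first hunk is a pure function from engine configuration to the list of draft lengths that get a CUDA graph captured. The following standalone sketch restates that logic; the EngineConfig dataclass and its field names are illustrative stand-ins for the attributes referenced in the diff (self.is_draft_model, self.model_is_wrapped, spec_config.max_concurrency, and so on), not the module's API:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class EngineConfig:
        # Illustrative stand-in for the attributes used in the diff.
        is_draft_model: bool
        model_is_wrapped: bool          # True on the CDL drafting-loop path
        max_draft_len: int
        original_max_draft_len: int
        use_one_engine: bool            # spec_dec_mode.use_one_engine()
        max_concurrency: Optional[int]  # spec_config.max_concurrency

    def cuda_graph_draft_lengths(cfg: EngineConfig) -> List[int]:
        draft_lengths: List[int] = []
        if cfg.is_draft_model:
            if cfg.model_is_wrapped:
                # CDL path: draft_len > 0 encodes the drafting-loop iteration count.
                draft_lengths.append(cfg.original_max_draft_len)
            else:
                draft_lengths.append(cfg.max_draft_len)
        else:
            # Also capture draft_len == 0 for the target model so speculation
            # can be switched off at runtime (not possible in one-engine mode,
            # and only worth the memory when max_concurrency is set).
            if (cfg.max_draft_len > 0 and not cfg.use_one_engine
                    and cfg.max_concurrency is not None):
                draft_lengths.append(0)
            draft_lengths.append(cfg.max_draft_len)
        return draft_lengths

    # Example: a target model with runtime-toggleable speculation gets graphs
    # captured for both draft_len 0 and draft_len 4.
    cfg = EngineConfig(is_draft_model=False, model_is_wrapped=False,
                       max_draft_len=4, original_max_draft_len=4,
                       use_one_engine=False, max_concurrency=8)
    assert cuda_graph_draft_lengths(cfg) == [0, 4]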
