33 changes: 20 additions & 13 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -695,19 +695,25 @@ def release_batch(result: ScheduledRequests | None):
cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                reverse=True)
# Create CUDA graphs for different draft lengths
draft_lengths = [self.max_draft_len]
# For the non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that in one-engine mode, we cannot turn off spec decode at runtime.
if (not self.is_draft_model and self.max_draft_len > 0
        and not self.spec_config.spec_dec_mode.use_one_engine()
        # Assume that speculation is always on if the user didn't give us a max_concurrency
        # value. This will save on memory.
        and self.spec_config.max_concurrency is not None):
    draft_lengths.append(0)
if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
        spec_resource_manager, Eagle3ResourceManager):
    draft_lengths.append(self.original_max_draft_len)
draft_lengths = []
if self.is_draft_model:
    if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
            spec_resource_manager, Eagle3ResourceManager):
        # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
        draft_lengths.append(self.original_max_draft_len)
    else:
        draft_lengths.append(self.max_draft_len)
else:
    # For the non-draft model, we also capture the CUDA graph instance for draft length 0,
    # so that when we disable spec decode at runtime, we can still run the captured graph.
    # Note that in one-engine mode, we cannot turn off spec decode at runtime.
    if (self.max_draft_len > 0
            and not self.spec_config.spec_dec_mode.use_one_engine()
            # Assume that speculation is always on if the user didn't give us a max_concurrency
            # value. This will save on memory.
            and self.spec_config.max_concurrency is not None):
        draft_lengths.append(0)
    draft_lengths.append(self.max_draft_len)

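# --- Illustrative sketch (not part of the diff) ----------------------------------
# A minimal standalone rendering of the draft-length selection above, assuming
# plain values in place of the ModelEngine attributes. `select_draft_lengths`
# and its parameter names are hypothetical, introduced here only for illustration.
def select_draft_lengths(is_draft_model: bool, model_is_wrapped: bool,
                         is_spec_decode: bool, has_eagle3_resources: bool,
                         max_draft_len: int, original_max_draft_len: int,
                         use_one_engine: bool,
                         max_concurrency: int | None) -> list[int]:
    """Return the draft lengths that need a captured CUDA graph."""
    draft_lengths: list[int] = []
    if is_draft_model:
        if model_is_wrapped and is_spec_decode and has_eagle3_resources:
            # CDL path: capture the graph for the full original draft length.
            draft_lengths.append(original_max_draft_len)
        else:
            draft_lengths.append(max_draft_len)
    else:
        # A draft_len == 0 graph is only useful when spec decode can be turned
        # off at runtime (two-engine mode with an explicit max_concurrency).
        if max_draft_len > 0 and not use_one_engine and max_concurrency is not None:
            draft_lengths.append(0)
        draft_lengths.append(max_draft_len)
    return draft_lengths

# Example: target model with runtime-toggleable spec decode -> graphs for 0 and 3.
assert select_draft_lengths(False, False, True, False, 3, 3, False, 8) == [0, 3]
# ----------------------------------------------------------------------------------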
for bs in cuda_graph_batch_sizes:
    if bs > self.batch_size:
@@ -723,6 +729,7 @@ def release_batch(result: ScheduledRequests | None):
logger.info(
    f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
)

self.enable_spec_decode = draft_len > 0 or self.is_draft_model

def _update_draft_inference_state(is_first_draft: bool,