diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index db38219fe03..ef8bb0aea3b 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -695,19 +695,25 @@ def release_batch(result: ScheduledRequests | None):
         cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                         reverse=True)
         # Create CUDA graphs for different draft lengths
-        draft_lengths = [self.max_draft_len]
-        # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-        # so that when we disable spec decode at runtime, we can still run the captured graph.
-        # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            draft_lengths.append(self.original_max_draft_len)
+        draft_lengths = []
+        if self.is_draft_model:
+            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
+                    spec_resource_manager, Eagle3ResourceManager):
+                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                draft_lengths.append(self.original_max_draft_len)
+            else:
+                draft_lengths.append(self.max_draft_len)
+        else:
+            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+            # so that when we disable spec decode at runtime, we can still run the captured graph.
+            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+            if (self.max_draft_len > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    # Assume that speculation is always on if the user didn't give us a max_concurrency
+                    # value. This will save on memory.
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)
+            draft_lengths.append(self.max_draft_len)
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
@@ -723,6 +729,7 @@ def release_batch(result: ScheduledRequests | None):
                     logger.info(
                         f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                     )
+                    self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
                     def _update_draft_inference_state(is_first_draft: bool,
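
For reference, below is a minimal, self-contained sketch of the draft-length selection logic introduced by this diff. EngineState and its flags are hypothetical stand-ins for the corresponding model-engine attributes (self.is_draft_model, self.model_is_wrapped, the Eagle3ResourceManager check, and so on); the sketch only illustrates which draft lengths get a CUDA graph captured under the new branching, not the actual TensorRT-LLM implementation.

# Hypothetical sketch of the draft_lengths selection added above.
# All names below are stand-ins, not the real TensorRT-LLM API.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EngineState:
    is_draft_model: bool
    model_is_wrapped: bool               # True when the drafting loop (CDL) wraps the model
    is_spec_decode: bool
    has_eagle3_resource_manager: bool    # stand-in for the isinstance(Eagle3ResourceManager) check
    max_draft_len: int
    original_max_draft_len: int
    use_one_engine: bool
    max_concurrency: Optional[int]


def select_draft_lengths(s: EngineState) -> List[int]:
    """Mirror of the new branching: which draft lengths get a CUDA graph captured."""
    draft_lengths: List[int] = []
    if s.is_draft_model:
        if s.model_is_wrapped and s.is_spec_decode and s.has_eagle3_resource_manager:
            # CDL path: capture the graph for the full drafting-loop length.
            draft_lengths.append(s.original_max_draft_len)
        else:
            draft_lengths.append(s.max_draft_len)
    else:
        # Target model: also capture a draft_len=0 graph so speculation can be
        # turned off at runtime, but only when runtime switching is possible.
        if (s.max_draft_len > 0 and not s.use_one_engine
                and s.max_concurrency is not None):
            draft_lengths.append(0)
        draft_lengths.append(s.max_draft_len)
    return draft_lengths


if __name__ == "__main__":
    # Target (non-draft) model with runtime-switchable speculation:
    # graphs are captured for draft_len 0 and for max_draft_len.
    target = EngineState(is_draft_model=False, model_is_wrapped=False,
                         is_spec_decode=True, has_eagle3_resource_manager=False,
                         max_draft_len=3, original_max_draft_len=3,
                         use_one_engine=False, max_concurrency=8)
    print(select_draft_lengths(target))  # [0, 3]

    # Wrapped draft model on the CDL path: one graph for the original max draft length.
    draft = EngineState(is_draft_model=True, model_is_wrapped=True,
                        is_spec_decode=True, has_eagle3_resource_manager=True,
                        max_draft_len=1, original_max_draft_len=3,
                        use_one_engine=False, max_concurrency=None)
    print(select_draft_lengths(draft))  # [3]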