@@ -695,19 +695,25 @@ def release_batch(result: ScheduledRequests | None):
695695 cuda_graph_batch_sizes = sorted (self ._cuda_graph_batch_sizes ,
696696 reverse = True )
697697 # Create CUDA graphs for different draft lengths
698- draft_lengths = [self .max_draft_len ]
699- # For non-draft model, we also capture the CUDA graph instance for draft length 0,
700- # so that when we disable spec decode at runtime, we can still run the captured graph.
701- # Note that for one engine mode, we are not able to turn off spec decode at runtime.
702- if (not self .is_draft_model and self .max_draft_len > 0
703- and not self .spec_config .spec_dec_mode .use_one_engine ()
704- # Assume that speculation is always on if the user didn't give us a max_concurrency
705- # value. This will save on memory.
706- and self .spec_config .max_concurrency is not None ):
707- draft_lengths .append (0 )
708- if self .is_spec_decode and self .is_draft_model and spec_resource_manager is not None and isinstance (
709- spec_resource_manager , Eagle3ResourceManager ):
710- draft_lengths .append (self .original_max_draft_len )
698+ draft_lengths = []
699+ if self .is_draft_model :
700+ if self .model_is_wrapped and self .is_spec_decode and spec_resource_manager is not None and isinstance (
701+ spec_resource_manager , Eagle3ResourceManager ):
702+ # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
703+ draft_lengths .append (self .original_max_draft_len )
704+ else :
705+ draft_lengths .append (self .max_draft_len )
706+ else :
707+ # For non-draft model, we also capture the CUDA graph instance for draft length 0,
708+ # so that when we disable spec decode at runtime, we can still run the captured graph.
709+ # Note that for one engine mode, we are not able to turn off spec decode at runtime.
710+ if (self .max_draft_len > 0
711+ and not self .spec_config .spec_dec_mode .use_one_engine ()
712+ # Assume that speculation is always on if the user didn't give us a max_concurrency
713+ # value. This will save on memory.
714+ and self .spec_config .max_concurrency is not None ):
715+ draft_lengths .append (0 )
716+ draft_lengths = [self .max_draft_len ]
711717
712718 for bs in cuda_graph_batch_sizes :
713719 if bs > self .batch_size :
@@ -723,6 +729,7 @@ def release_batch(result: ScheduledRequests | None):
723729 logger .info (
724730 f"Run generation only CUDA graph warmup for batch size={ bs } , draft_len={ draft_len } "
725731 )
732+
726733 self .enable_spec_decode = draft_len > 0 or self .is_draft_model
727734
728735 def _update_draft_inference_state (is_first_draft : bool ,
0 commit comments