@@ -695,19 +695,24 @@ def release_batch(result: ScheduledRequests | None):
         cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                         reverse=True)
         # Create CUDA graphs for different draft lengths
-        draft_lengths = [self.max_draft_len]
-        # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-        # so that when we disable spec decode at runtime, we can still run the captured graph.
-        # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            draft_lengths.append(self.original_max_draft_len)
+        draft_lengths = []
+        if self.is_draft_model:
+            if self.model_is_wrapped:
+                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                draft_lengths.append(self.original_max_draft_len)
+            else:
+                draft_lengths.append(self.max_draft_len)
+        else:
+            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+            # so that when we disable spec decode at runtime, we can still run the captured graph.
+            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+            if (self.max_draft_len > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    # Assume that speculation is always on if the user didn't give us a max_concurrency
+                    # value. This will save on memory.
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)
+            draft_lengths.append(self.max_draft_len)
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
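The new branch structure is easier to see in isolation. Below is a minimal sketch of the selection logic as a standalone helper; the function name and flat-argument signature are illustrative only, since in the PR this logic lives inline on the engine and reads the same fields off `self`:

    # Hypothetical standalone version of the draft-length selection above.
    def select_draft_lengths(is_draft_model: bool, model_is_wrapped: bool,
                             max_draft_len: int, original_max_draft_len: int,
                             use_one_engine: bool,
                             max_concurrency: int | None) -> list[int]:
        draft_lengths = []
        if is_draft_model:
            if model_is_wrapped:
                # CDL path: draft_len counts the drafting-loop iterations.
                draft_lengths.append(original_max_draft_len)
            else:
                draft_lengths.append(max_draft_len)
        else:
            # Capture an extra draft_len == 0 graph only when spec decode can
            # actually be disabled at runtime: two-engine mode with a
            # user-provided max_concurrency cap.
            if (max_draft_len > 0 and not use_one_engine
                    and max_concurrency is not None):
                draft_lengths.append(0)
            draft_lengths.append(max_draft_len)
        return draft_lengths

For example, a target (non-draft) model with max_draft_len=4, two-engine mode, and max_concurrency=8 gets [0, 4]: one graph for the spec-off path and one for the spec-on path. A wrapped draft model gets [original_max_draft_len] only.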
@@ -723,6 +728,7 @@ def release_batch(result: ScheduledRequests | None):
                 logger.info(
                     f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                 )
+
                 self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
     def _update_draft_inference_state(is_first_draft: bool,
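For context, both hunks sit inside a warmup loop over the CUDA graph batch sizes. A rough sketch of that loop's shape, assuming a hypothetical `engine.capture_graph` call in place of the engine's real capture machinery (only the iteration pattern and the log line come from the diff):

    import logging

    logger = logging.getLogger(__name__)

    # Sketch of the warmup loop implied by the hunks above; capture_graph
    # is a placeholder name, not the PR's actual API.
    def warmup_cuda_graphs(engine, draft_lengths):
        cuda_graph_batch_sizes = sorted(engine._cuda_graph_batch_sizes,
                                        reverse=True)
        for bs in cuda_graph_batch_sizes:
            if bs > engine.batch_size:
                continue  # the engine can never schedule this batch size
            for draft_len in draft_lengths:
                logger.info(
                    f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                )
                # Spec decode stays on for any nonzero draft length, and
                # unconditionally for the draft model itself.
                engine.enable_spec_decode = draft_len > 0 or engine.is_draft_model
                engine.capture_graph(batch_size=bs, draft_len=draft_len)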