@@ -697,19 +697,24 @@ def release_batch(result: ScheduledRequests | None):
         cuda_graph_batch_sizes = sorted(self._cuda_graph_batch_sizes,
                                         reverse=True)
         # Create CUDA graphs for different draft lengths
-        draft_lengths = [self.max_draft_len]
-        # For non-draft model, we also capture the CUDA graph instance for draft length 0,
-        # so that when we disable spec decode at runtime, we can still run the captured graph.
-        # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
-            draft_lengths.append(0)
-        if self.is_spec_decode and self.is_draft_model and spec_resource_manager is not None and isinstance(
-                spec_resource_manager, Eagle3ResourceManager):
-            draft_lengths.append(self.original_max_draft_len)
+        draft_lengths = []
+        if self.is_draft_model:
+            if self.model_is_wrapped:
+                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
+                draft_lengths.append(self.original_max_draft_len)
+            else:
+                draft_lengths.append(self.max_draft_len)
+        else:
+            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
+            # so that when we disable spec decode at runtime, we can still run the captured graph.
+            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
+            if (self.max_draft_len > 0
+                    and not self.spec_config.spec_dec_mode.use_one_engine()
+                    # Assume that speculation is always on if the user didn't give us a max_concurrency
+                    # value. This will save on memory.
+                    and self.spec_config.max_concurrency is not None):
+                draft_lengths.append(0)
+            draft_lengths.append(self.max_draft_len)
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
@@ -725,6 +730,7 @@ def release_batch(result: ScheduledRequests | None):
                 logger.info(
                     f"Run generation only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
                 )
+
                 self.enable_spec_decode = draft_len > 0 or self.is_draft_model
 
         def _update_draft_inference_state(is_first_draft: bool,
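
Note: the draft-length selection introduced in the first hunk can be read as the standalone sketch below. The `EngineState` container and its `use_one_engine` flag are hypothetical stand-ins for `self` and `self.spec_config.spec_dec_mode.use_one_engine()`; the remaining attribute names mirror the ones used in the diff, and everything else is illustrative rather than the engine's actual API.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class EngineState:
    # Hypothetical stand-in for the model engine; field names mirror the diff.
    is_draft_model: bool
    model_is_wrapped: bool
    max_draft_len: int
    original_max_draft_len: int
    use_one_engine: bool
    max_concurrency: Optional[int]

def select_draft_lengths(e: EngineState) -> List[int]:
    """Return the draft lengths to capture CUDA graphs for."""
    if e.is_draft_model:
        # CDL path: draft_len counts drafting-loop iterations, so use the
        # original (pre-wrapping) maximum; otherwise the current maximum.
        return [e.original_max_draft_len if e.model_is_wrapped else e.max_draft_len]
    lengths = []
    # Capture a draft_len == 0 graph only when spec decode can actually be
    # toggled off at runtime: two-engine mode with a max_concurrency cap set.
    if e.max_draft_len > 0 and not e.use_one_engine and e.max_concurrency is not None:
        lengths.append(0)
    lengths.append(e.max_draft_len)
    return lengths

# Target model, two-engine mode, concurrency cap set: graphs for 0 and 3.
assert select_draft_lengths(EngineState(False, False, 3, 3, False, 8)) == [0, 3]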