 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
+                                                    FlashAttentionMetadata)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 
 logger = init_logger(__name__)
 
@@ -25,12 +26,15 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
 
+        self.runner = runner
+
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
@@ -106,24 +110,46 @@ def propose(
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
 
-        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-        attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
-            use_cascade=False,
-            common_prefix_len=0,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
-        )
+        if self.method in ["eagle", "eagle3"]:
+            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
+            max_seq_len = seq_lens.max().item()
+            max_num_tokens = (cu_num_tokens[1:] -
+                              cu_num_tokens[:-1]).max().item()
+            attn_metadata = FlashAttentionMetadata(
+                num_actual_tokens=num_tokens,
+                max_query_len=max_num_tokens,
+                query_start_loc=cu_num_tokens,
+                max_seq_len=max_seq_len,
+                seq_lens=seq_lens,
+                block_table=block_table,
+                slot_mapping=target_slot_mapping,
+                # TODO(woosuk): Support cascade attention.
+                use_cascade=False,
+                common_prefix_len=0,
+                cu_prefix_query_lens=None,
+                prefix_kv_lens=None,
+                suffix_kv_lens=None,
+            )
+        elif self.method == "deepseek_mtp":
+            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+            max_query_len = query_lens.max().item()
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+
+            assert self.runner is not None
+
+            # FIXME: need to consider multiple kv_cache_groups
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                num_reqs=batch_size,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise ValueError(f"Unsupported method: {self.method}")
+
         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -136,11 +162,15 @@ def propose(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+            ret_hidden_states = self.model(
+                self.input_ids[:num_input_tokens],
+                self.positions[:num_input_tokens],
+                self.hidden_states[:num_input_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
@@ -150,6 +180,10 @@ def propose(
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)
 
+        # TODO: Currently, MTP module released by deepseek only has
+        # one layer. Adapt this code to support multiple layers once
+        # there's a multi-layer MTP module.
+
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
@@ -215,9 +249,9 @@ def propose(
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
+                    self.input_ids[:input_batch_size],
+                    self.positions[:input_batch_size],
+                    self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -268,7 +302,7 @@ def prepare_inputs(
 
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
-        prepare_input_kernel[(batch_size, )](
+        prepare_eagle_input_kernel[(batch_size, )](
             token_indices,
             cu_target_query_lens,
             cu_num_tokens,
@@ -320,9 +354,9 @@ def dummy_run(
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
-                input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
+                self.input_ids[:num_tokens],
+                self.positions[:num_tokens],
+                self.hidden_states[:num_tokens],
             )
 
 
@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
             next_token_ids,
         )
     return next_token_ids, probs
-
-
-@triton.jit
-def prepare_input_kernel(
-    out_ptr,
-    cu_query_lens_ptr,
-    cu_num_tokens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    # [start_pos, end_pos)
-    start_pos = tl.load(cu_num_tokens_ptr + pid)
-    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cu_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            out_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
-        )
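
Note: the module-local `prepare_input_kernel` removed above is replaced by the `prepare_eagle_input_kernel` now imported from `vllm.v1.spec_decode.utils`. As a minimal sketch, assuming the relocated kernel keeps the semantics of the body deleted here, the pure-PyTorch reference below (`prepare_eagle_input_ref` is a hypothetical name used only for illustration, not part of vLLM) shows what the kernel writes into its output buffer: for each request, the gather indices that map the compacted post-rejection token layout back into the target model's query layout.

```python
import torch


def prepare_eagle_input_ref(cu_query_lens: torch.Tensor,
                            cu_num_tokens: torch.Tensor) -> torch.Tensor:
    # cu_query_lens:  [batch_size + 1] cumulative query lengths in the
    #                 target model's (pre-rejection) token layout.
    # cu_num_tokens:  [batch_size + 1] cumulative token counts after the
    #                 rejected tokens of each request have been dropped.
    total = int(cu_num_tokens[-1].item())
    out = torch.empty(total, dtype=torch.int64, device=cu_num_tokens.device)
    batch_size = cu_num_tokens.numel() - 1
    for i in range(batch_size):
        start = int(cu_num_tokens[i].item())
        end = int(cu_num_tokens[i + 1].item())
        # Mirrors the Triton kernel: out[start + j] = cu_query_lens[i] + j
        out[start:end] = cu_query_lens[i] + torch.arange(
            end - start, device=out.device)
    return out


# Example: two requests with target query lengths [5, 4] and [1, 2]
# rejected tokens -> compacted lengths [4, 2].
cu_query_lens = torch.tensor([0, 5, 9])
cu_num_tokens = torch.tensor([0, 4, 6])
print(prepare_eagle_input_ref(cu_query_lens, cu_num_tokens))
# tensor([0, 1, 2, 3, 5, 6])
```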