@@ -140,6 +140,13 @@ class MTPSpecMetadata(SpecMetadata):
     slot_ids: Optional[torch.Tensor] = None
     # The index of the batch inputs
     batch_indices_cuda: Optional[torch.Tensor] = None
+    # The number of sequences for the speculative model/layer on each rank
+    _all_rank_num_seqs: Optional[List[int]] = None
+    # This is used for attention dp in the MTP Eagle worker. The number of input
+    # tokens varies between the 1st draft forward and subsequent ones. To support
+    # CUDA graphs, this list stores the number of input tokens for the
+    # subsequent draft forwards.
+    subseq_all_rank_num_tokens: Optional[List[int]] = None
 
     def __post_init__(self) -> None:
         if self.mtp_hidden_states_manager is not None:
@@ -166,6 +173,16 @@ def __post_init__(self) -> None:
             device='cuda',
         )
 
+    @property
+    def all_rank_num_seqs(self):
+        return self._all_rank_num_seqs
+
+    @all_rank_num_seqs.setter
+    def all_rank_num_seqs(self, value: List[int]):
+        self._all_rank_num_seqs = value
+        if self.spec_dec_mode.is_mtp_eagle():
+            self.subseq_all_rank_num_tokens = value
+
     def prepare(self):
         assert self.request_ids is not None
         num_seqs = len(self.request_ids)
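
To illustrate the new setter's side effect, here is a standalone sketch (a toy mock, not the real MTPSpecMetadata; the spec-dec-mode check is reduced to a plain boolean): assigning all_rank_num_seqs also primes subseq_all_rank_num_tokens, since every Eagle draft forward after the first feeds exactly one token per sequence.

from typing import List, Optional

class FakeMTPSpecMetadata:
    def __init__(self, is_mtp_eagle: bool = True):
        self.is_mtp_eagle = is_mtp_eagle
        self.subseq_all_rank_num_tokens: Optional[List[int]] = None
        self._all_rank_num_seqs: Optional[List[int]] = None

    @property
    def all_rank_num_seqs(self) -> Optional[List[int]]:
        return self._all_rank_num_seqs

    @all_rank_num_seqs.setter
    def all_rank_num_seqs(self, value: List[int]) -> None:
        self._all_rank_num_seqs = value
        if self.is_mtp_eagle:
            # Subsequent Eagle draft forwards feed one token per sequence, so the
            # per-rank sequence counts double as per-rank token counts.
            self.subseq_all_rank_num_tokens = value

meta = FakeMTPSpecMetadata()
meta.all_rank_num_seqs = [4, 3, 5, 4]    # per-rank sequence counts (attention dp)
print(meta.subseq_all_rank_num_tokens)   # [4, 3, 5, 4]
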
@@ -176,10 +193,11 @@ def prepare(self):
                                      pin_memory=True)
         self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
                                                  non_blocking=True)
-        # MTP module need different number of input tokens in generation phase
-        if self.spec_dec_mode.is_mtp_eagle():
-            self.num_tokens -= (self.num_generations) * self.mtp_num_modules
-        else:
+        # The MTP vanilla worker uses max_draft_tokens input tokens for each generation
+        # request, while the MTP Eagle worker uses (max_draft_tokens + 1) input tokens
+        # in the 1st draft forward and only one input token per request in the following
+        # draft forwards. This num_tokens is used to set all_rank_num_tokens for attention dp.
+        if not self.spec_dec_mode.is_mtp_eagle():
             self.num_tokens -= self.num_generations
 
         if self.mtp_hidden_states_manager is not None:  # MTP vanilla or use relaxed acceptance
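
As a worked example of the token accounting above (numbers are illustrative, not from the PR): the target model processed (max_draft_tokens + 1) tokens per generation request, the MTP vanilla worker re-feeds only max_draft_tokens of them, and the MTP Eagle worker keeps all of them for its first draft forward.

max_draft_tokens = 3          # == mtp_num_modules
num_ctx_tokens = 20           # tokens contributed by context-phase requests
num_contexts = 4
num_generations = 2           # generation-phase requests

# Tokens seen by the target model in this iteration.
num_tokens = num_ctx_tokens + num_generations * (max_draft_tokens + 1)   # 28

# MTP vanilla: max_draft_tokens per generation request -> subtract one per request.
num_tokens_vanilla = num_tokens - num_generations                        # 26

# MTP Eagle: (max_draft_tokens + 1) per generation request in the 1st draft forward,
num_tokens_eagle_first = num_tokens                                      # 28
# then exactly one token per sequence in every subsequent draft forward.
num_tokens_eagle_subsequent = num_contexts + num_generations             # 6
print(num_tokens_vanilla, num_tokens_eagle_first, num_tokens_eagle_subsequent)
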
@@ -375,9 +393,9 @@ def forward(
                 num_accepted_tokens=num_accepted_tokens,
                 spec_metadata=spec_metadata,
                 attn_metadata=attn_metadata)
-            hidden_states, logits = mtp_layer(lm_head=lm_head,
-                                              embed_tokens=embed_tokens,
-                                              **draft_inputs)
+            hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
+            logits = mtp_layer.shared_head(hidden_states, lm_head,
+                                           attn_metadata).float()
             previous_layer_draft_tokens = self.draft_sampler(logits)
             next_draft_tokens.append(previous_layer_draft_tokens)
 
@@ -727,12 +745,13 @@ def sample_and_accept_draft_tokens(
             logits = logits.unsqueeze(0)
 
         # The return buffer
-        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
+        if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
+            accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
                                          dtype=torch.int,
                                          device=logits.device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=logits.device)
         if self.spec_config.use_relaxed_acceptance_for_thinking:
             mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
 
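
For reference, a standalone sketch of the buffer shapes allocated here (illustrative values; the reason for starting the count at one is an assumption, presumably because the newly sampled token is always kept):

import torch

batch_size, mtp_num_modules = 4, 3
# One row of (mtp_num_modules + 1) accepted-token slots per request.
accepted_tokens = torch.ones((batch_size, mtp_num_modules + 1), dtype=torch.int)
# Per-request acceptance count, starting at 1 (at least the sampled token survives).
num_accepted_tokens = torch.ones(batch_size, dtype=torch.int)
print(accepted_tokens.shape, num_accepted_tokens.tolist())
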
@@ -1021,7 +1040,6 @@ def prepare_drafter_inputs(
             "position_ids": position_ids,
             "hidden_states": return_hidden_states,
             "attn_metadata": attn_metadata,
-            "spec_metadata": spec_metadata,
         }
 
     def draft_sampler(
@@ -1066,6 +1084,7 @@ def forward(
     ):
         batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
+        num_gens = batch_size - num_contexts
 
         # Sample and verify draft tokens
         raw_logits = logits
@@ -1079,58 +1098,79 @@ def forward(
 
         # Prepare inputs for the 1st MTP layer
         position_ids = position_ids.squeeze(0)
-        inputs = self.prepare_drafter_inputs(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            hidden_states=hidden_states,
-            accepted_tokens=accepted_tokens,
-            num_accepted_tokens=num_accepted_tokens,
-            attn_metadata=attn_metadata,
-            spec_metadata=spec_metadata)
+        last_tokens_idx = torch.cumsum(
+            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+        inputs = self.prepare_drafter_inputs(input_ids=input_ids,
+                                             position_ids=position_ids,
+                                             last_tokens_idx=last_tokens_idx,
+                                             hidden_states=hidden_states,
+                                             accepted_tokens=accepted_tokens,
+                                             attn_metadata=attn_metadata,
+                                             spec_metadata=spec_metadata)
 
         # Predict draft tokens
         next_draft_tokens = []
         for i in range(self.mtp_num_modules):
-            hidden_states, logits = mtp_layers[0](lm_head=lm_head,
-                                                  embed_tokens=embed_tokens,
-                                                  **inputs)
+            if i == 0:
+                hidden_states = mtp_layers[0](
+                    embed_tokens=embed_tokens,
+                    all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
+                    **inputs)
+                start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
+                                 (self.mtp_num_modules + 1)).long()
+                gather_ids_gen = (start_ids_gen +
+                                  num_accepted_tokens[num_contexts:] - 1 +
+                                  attn_metadata.num_ctx_tokens)
+                gather_ids = torch.concat(
+                    [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
+            else:
+                hidden_states = mtp_layers[0](embed_tokens=embed_tokens,
+                                              all_rank_num_tokens=spec_metadata.
+                                              subseq_all_rank_num_tokens,
+                                              **inputs)
+                # All of the seq_lens are 1, so use batch_indices_cuda as gather_ids
+                gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
+            logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
+                                               lm_head, attn_metadata, True)
             new_draft_token = self.draft_sampler(logits)
             next_draft_tokens.append(new_draft_token)
             # update inputs
-            last_tokens = torch.cumsum(
-                attn_metadata.seq_lens_cuda,
-                dim=0,
-                dtype=torch.long,
-            ) - 1
-            position_ids = inputs["position_ids"][last_tokens] + 1
-            hidden_states = hidden_states[last_tokens]
-            attn_metadata._seq_lens[:attn_metadata.num_contexts].fill_(1)
-            attn_metadata._seq_lens_cuda[:attn_metadata.num_contexts].fill_(1)
-            attn_metadata.on_update()
-            # cannot run generation if their is no kv cache
-            if inputs["attn_metadata"].kv_cache_manager is not None:
-                attn_metadata.host_request_types[:attn_metadata.
-                                                 num_contexts].fill_(1)
-                attn_metadata.num_contexts = 0
-                if i == 0 and num_contexts > 0 and attn_metadata.enable_flash_mla:
+            hidden_states = hidden_states[gather_ids]
+            position_ids = inputs["position_ids"][gather_ids] + 1
+            # update attn_metadata
+            if i == 0:
+                attn_metadata._seq_lens[:batch_size].fill_(1)
+                attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+                attn_metadata.on_update()
+                # cannot run generation if there is no kv cache
+                has_kv_cache = inputs[
+                    "attn_metadata"].kv_cache_manager is not None
+                if has_kv_cache:
+                    attn_metadata.host_request_types[:attn_metadata.
+                                                     num_contexts].fill_(1)
+                    attn_metadata.num_contexts = 0
+                # update kv_lens_cuda
+                if hasattr(attn_metadata, 'kv_lens_cuda'):
+                    attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
+                        self.mtp_num_modules -
+                        num_accepted_tokens[num_contexts:])
+                    attn_metadata.kv_lens_cuda[:num_contexts] += 1
+                # update metadata for flash mla
+                if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
                     reorder_block_ids_per_seq = torch.cat([
                         attn_metadata.
                         kv_block_ids_per_seq[num_contexts:batch_size],
                         attn_metadata.kv_block_ids_per_seq[:num_contexts]
                     ])
                     attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
                         reorder_block_ids_per_seq, non_blocking=True)
-            if hasattr(attn_metadata, 'kv_lens_cuda'):
+            elif hasattr(attn_metadata, 'kv_lens_cuda'):
                 attn_metadata.kv_lens_cuda[:batch_size] += 1
-            # support attention dp
-            if spec_metadata.all_rank_num_tokens is not None:
-                spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
             inputs = {
                 "input_ids": new_draft_token,
                 "position_ids": position_ids,
                 "hidden_states": hidden_states,
                 "attn_metadata": attn_metadata,
-                "spec_metadata": spec_metadata,
             }
         next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
 
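
To make the first-iteration gathering above concrete, here is a small standalone sketch (illustrative numbers; tensor names mirror the diff but nothing here is taken from the model) of how gather_ids picks one hidden state per request: the last context token for context requests, and the last accepted token inside each generation request's (mtp_num_modules + 1)-token window.

import torch

mtp_num_modules = 3                      # draft length per generation request
seq_lens = torch.tensor([5, 7, 4, 4])    # 2 context requests + 2 generation requests
num_contexts, num_gens = 2, 2
num_ctx_tokens = int(seq_lens[:num_contexts].sum())        # 12
num_accepted_tokens = torch.tensor([1, 1, 2, 4])           # per request

last_tokens_idx = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1   # [4, 11, 15, 19]
batch_indices = torch.arange(num_gens)                                  # [0, 1]

# Each generation request owns a (mtp_num_modules + 1)-token window after the
# context tokens; gather the hidden state of its last accepted token.
start_ids_gen = batch_indices * (mtp_num_modules + 1)                   # [0, 4]
gather_ids_gen = start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + num_ctx_tokens
# -> [13, 19]: request 2 accepted 2 tokens, request 3 accepted all 4.

gather_ids = torch.concat([last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
print(gather_ids.tolist())   # [4, 11, 13, 19]
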
@@ -1159,66 +1199,32 @@ def prepare_drafter_inputs(
         self,
         input_ids: torch.IntTensor,
         position_ids: torch.IntTensor,
+        last_tokens_idx: torch.LongTensor,
         hidden_states: torch.Tensor,
         accepted_tokens: torch.Tensor,
-        num_accepted_tokens: torch.Tensor,
         attn_metadata: AttentionMetadata,
         spec_metadata: MTPSpecMetadata,
     ):
-        batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
-        num_gens = batch_size - num_contexts
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        hidden_size = hidden_states.shape[1]
-        last_tokens_idx = torch.cumsum(
-            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
 
         # context
-        hidden_states_ctx = hidden_states[:attn_metadata.num_ctx_tokens, :]
         input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
         input_ids_ctx = torch.empty_like(input_ctx_ids,
                                          dtype=torch.int32,
                                          device="cuda")
         input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
         input_ids_ctx[
             last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
-        position_ids_ctx = position_ids[:num_ctx_tokens]
 
         # generation
-        gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
-        gen_token_idx = num_accepted_tokens[num_contexts:] - 1
-        hidden_states_gen = hidden_states[attn_metadata.num_ctx_tokens:, :]
-        hidden_states_gen = hidden_states_gen.reshape(num_gens,
-                                                      self.mtp_num_modules + 1,
-                                                      hidden_size)
-        hidden_states_gen = hidden_states_gen[gen_batch_idx, gen_token_idx, :]
-        accepted_tokens_gen = accepted_tokens[num_contexts:, :]
-        input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx]
-        position_ids_gen = position_ids[num_ctx_tokens:].reshape(
-            num_gens, self.mtp_num_modules + 1)
-        position_ids_gen = position_ids_gen[gen_batch_idx, gen_token_idx]
+        input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
 
         # get draft inputs
         input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
-        hidden_states = torch.concat([hidden_states_ctx, hidden_states_gen],
-                                     dim=0)
-        position_ids = torch.concat([position_ids_ctx, position_ids_gen], dim=0)
-
-        # change attn_metadata
-        attn_metadata._seq_lens[num_contexts:batch_size].fill_(1)
-        attn_metadata._seq_lens_cuda[num_contexts:batch_size].fill_(1)
-        attn_metadata.on_update()
-        if hasattr(attn_metadata, 'kv_lens_cuda'):
-            # Note that it's important to not free the seq_lens_cuda
-            # buffer once the graph has been captured also - this will invalidate
-            # the graph and force an expensive recapture.
-            attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
-                self.mtp_num_modules + 1 - num_accepted_tokens[num_contexts:])
 
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
             "hidden_states": hidden_states,
             "attn_metadata": attn_metadata,
-            "spec_metadata": spec_metadata,
         }
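
The pared-down prepare_drafter_inputs above can be illustrated with a tiny standalone example (toy token ids, not from the source): context requests are shifted left by one position with their first accepted token written into the last slot, and generation requests simply feed all of their accepted slots, flattened.

import torch

seq_lens = torch.tensor([4, 3])                       # two context requests
last_tokens_idx = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1   # [3, 6]
num_contexts = 2

input_ids = torch.tensor([10, 11, 12, 13, 20, 21, 22])    # context-phase tokens
accepted_tokens = torch.tensor([[14, 0], [23, 0],          # context requests: 1 accepted each
                                [31, 32]])                  # one generation request, 2 accepted

# Context requests: shift left by one and append the newly accepted token.
input_ids_ctx = torch.empty_like(input_ids)
input_ids_ctx[:-1].copy_(input_ids[1:])
input_ids_ctx[last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
# -> [11, 12, 13, 14, 21, 22, 23]

# Generation requests: feed every accepted slot, flattened.
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()   # [31, 32]

drafter_input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
print(drafter_input_ids.tolist())   # [11, 12, 13, 14, 21, 22, 23, 31, 32]
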