@@ -391,6 +391,58 @@ def append_eagle3(tokens: torch.Tensor, model_outputs):
             d2t = model_outputs["d2t"][tokens]
             tokens += d2t
 
+    @staticmethod
+    def _apply_embedding_bias(
+            logits: torch.Tensor,
+            requests: list[LlmRequest],
+            steps_per_request: list[int] = None) -> torch.Tensor:
+        """Apply embedding bias (aka logit bias) to logits.
+        If steps_per_request is None, assumes 1 step per request (non-batched path).
+        """
+        # Collect biases and their associated data
+        bias_list = []
+        bias_data = []  # Either indices (fast path) or steps (batched path)
+
+        for i, req in enumerate(requests):
+            bias = req._py_embedding_bias_1d
+            if bias is not None:
+                bias_list.append(bias)
+                bias_data.append(i if steps_per_request is
+                                 None else steps_per_request[i])
+
+        if not bias_list:
+            return logits
+
+        bias_tensor = torch.stack(bias_list).to(logits.device,
+                                                non_blocking=True)
+        logits = logits.clone()
+
+        if steps_per_request is None:
+            # Fast path: direct indexing
+            indices = torch.tensor(bias_data, device=logits.device)
+            logits[indices] += bias_tensor
+        else:
+            # Batched path: expand biases and use boolean mask
+            expanded_biases = torch.repeat_interleave(bias_tensor,
+                                                      torch.tensor(
+                                                          bias_data,
+                                                          device=logits.device),
+                                                      dim=0)
+
+            mask = torch.zeros(sum(steps_per_request),
+                               dtype=torch.bool,
+                               device=logits.device)
+            offset = 0
+            for i, req in enumerate(requests):
+                steps = steps_per_request[i]
+                if req._py_embedding_bias_1d is not None:
+                    mask[offset:offset + steps] = True
+                offset += steps
+
+            logits[mask] += expanded_biases
+
+        return logits
+
     def _process_requests(self,
                           requests: list[LlmRequest],
                           model_outputs: dict[str, torch.Tensor],
@@ -411,6 +463,7 @@ def _process_requests(self,
 
         if fast_path:
             logits = raw_logits[:len(requests)]
+            logits = self._apply_embedding_bias(logits, requests)
             next_tokens = torch.argmax(logits, dim=-1)
             self.append_eagle3(next_tokens, model_outputs)
             int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
@@ -430,17 +483,29 @@ def _process_requests(self,
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
+            # Collect steps per request for batched strategy
+            steps_per_request = [
+                1 + len(req.py_draft_tokens) for req in requests
+            ]
+            logits = self._apply_embedding_bias(logits, requests,
+                                                steps_per_request)
             batched_next_tokens, batched_softmax = sample(
                 batched_strategy, logits)
             self.append_eagle3(batched_next_tokens, model_outputs)
 
         offset = 0
-        for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
+        for i, (strategy, slot,
+                steps) in enumerate(zip(strategies, seq_slots, num_steps)):
             input_slice = slice(offset, offset + steps)
             logits = raw_logits[input_slice]
+
+            req = requests[i]
+
             if batched_next_tokens is None:
+                logits = self._apply_embedding_bias(logits, [req])
                 next_tokens, softmax = sample(strategy, logits)
             else:
+                # Batched processing already applied bias, just use the results
                 next_tokens = batched_next_tokens[input_slice]
                 softmax = batched_softmax[input_slice]
             current_slice = slice(0, steps), slot, beam
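
Note (not part of the patch): below is a minimal, self-contained sketch that exercises the same two code paths as `_apply_embedding_bias`, useful for checking the indexing logic in isolation. `FakeRequest` and `apply_embedding_bias` are hypothetical stand-ins introduced only for this example; the real `LlmRequest` has many more fields than the single `_py_embedding_bias_1d` attribute the helper reads.

```python
import torch


class FakeRequest:
    """Hypothetical stand-in for LlmRequest; only holds the bias attribute."""

    def __init__(self, bias=None):
        self._py_embedding_bias_1d = bias


def apply_embedding_bias(logits, requests, steps_per_request=None):
    """Same two-path logic as the patch, reimplemented for illustration."""
    bias_list, bias_data = [], []
    for i, req in enumerate(requests):
        bias = req._py_embedding_bias_1d
        if bias is not None:
            bias_list.append(bias)
            bias_data.append(i if steps_per_request is None else steps_per_request[i])
    if not bias_list:
        return logits

    bias_tensor = torch.stack(bias_list).to(logits.device)
    logits = logits.clone()
    if steps_per_request is None:
        # Fast path: one logits row per request, index the biased rows directly.
        logits[torch.tensor(bias_data, device=logits.device)] += bias_tensor
    else:
        # Batched path: repeat each bias over that request's steps, mask the rows.
        expanded = torch.repeat_interleave(
            bias_tensor, torch.tensor(bias_data, device=logits.device), dim=0)
        mask = torch.zeros(sum(steps_per_request),
                           dtype=torch.bool,
                           device=logits.device)
        offset = 0
        for steps, req in zip(steps_per_request, requests):
            if req._py_embedding_bias_1d is not None:
                mask[offset:offset + steps] = True
            offset += steps
        logits[mask] += expanded
    return logits


vocab = 8
requests = [FakeRequest(torch.full((vocab,), -100.0)), FakeRequest(None)]

# Fast path: one row per request; only request 0 is biased.
fast = apply_embedding_bias(torch.zeros(2, vocab), requests)
assert torch.all(fast[0] == -100.0) and torch.all(fast[1] == 0.0)

# Batched path: request 0 contributes 3 rows, request 1 contributes 2; only rows 0-2 are biased.
batched = apply_embedding_bias(torch.zeros(5, vocab), requests, steps_per_request=[3, 2])
assert torch.all(batched[:3] == -100.0) and torch.all(batched[3:] == 0.0)
```

The fast path biases exactly one logits row per request, while the batched path repeats each request's bias over its `1 + len(py_draft_tokens)` rows, which is why `_process_requests` passes `steps_per_request` only on the batched-strategy branch.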