@@ -391,6 +391,58 @@ def append_eagle3(tokens: torch.Tensor, model_outputs):
         d2t = model_outputs["d2t"][tokens]
         tokens += d2t
 
+    @staticmethod
+    def _apply_embedding_bias(
+            logits: torch.Tensor,
+            requests: list[LlmRequest],
+            steps_per_request: list[int] = None) -> torch.Tensor:
+        """Apply embedding bias (aka logit bias) to logits.
+        If steps_per_request is None, assumes 1 step per request (non-batched path).
+        """
+        # Collect biases and their associated data
+        bias_list = []
+        bias_data = []  # Either indices (fast path) or steps (batched path)
+
+        for i, req in enumerate(requests):
+            bias = req._py_embedding_bias_1d
+            if bias is not None:
+                bias_list.append(bias)
+                bias_data.append(i if steps_per_request is
+                                 None else steps_per_request[i])
+
+        if not bias_list:
+            return logits
+
+        bias_tensor = torch.stack(bias_list).to(logits.device,
+                                                non_blocking=True)
+        logits = logits.clone()
+
+        if steps_per_request is None:
+            # Fast path: direct indexing
+            indices = torch.tensor(bias_data, device=logits.device)
+            logits[indices] += bias_tensor
+        else:
+            # Batched path: expand biases and use boolean mask
+            expanded_biases = torch.repeat_interleave(bias_tensor,
+                                                      torch.tensor(
+                                                          bias_data,
+                                                          device=logits.device),
+                                                      dim=0)
+
+            mask = torch.zeros(sum(steps_per_request),
+                               dtype=torch.bool,
+                               device=logits.device)
+            offset = 0
+            for i, req in enumerate(requests):
+                steps = steps_per_request[i]
+                if req._py_embedding_bias_1d is not None:
+                    mask[offset:offset + steps] = True
+                offset += steps
+
+            logits[mask] += expanded_biases
+
+        return logits
+
     def _process_requests(self,
                           requests: list[LlmRequest],
                           model_outputs: dict[str, torch.Tensor],
@@ -411,6 +463,7 @@ def _process_requests(self,
 
         if fast_path:
             logits = raw_logits[:len(requests)]
+            logits = self._apply_embedding_bias(logits, requests)
             next_tokens = torch.argmax(logits, dim=-1)
             self.append_eagle3(next_tokens, model_outputs)
             int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
@@ -430,17 +483,29 @@ def _process_requests(self,
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
+            # Collect steps per request for batched strategy
+            steps_per_request = [
+                1 + len(req.py_draft_tokens) for req in requests
+            ]
+            logits = self._apply_embedding_bias(logits, requests,
+                                                steps_per_request)
             batched_next_tokens, batched_softmax = sample(
                 batched_strategy, logits)
             self.append_eagle3(batched_next_tokens, model_outputs)
 
         offset = 0
-        for strategy, slot, steps in zip(strategies, seq_slots, num_steps):
+        for i, (strategy, slot,
+                steps) in enumerate(zip(strategies, seq_slots, num_steps)):
             input_slice = slice(offset, offset + steps)
             logits = raw_logits[input_slice]
+
+            req = requests[i]
+
             if batched_next_tokens is None:
+                logits = self._apply_embedding_bias(logits, [req])
                 next_tokens, softmax = sample(strategy, logits)
             else:
+                # Batched processing already applied bias, just use the results
                 next_tokens = batched_next_tokens[input_slice]
                 softmax = batched_softmax[input_slice]
             current_slice = slice(0, steps), slot, beam
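
Below is an illustrative, standalone sketch of the two bias-application paths added in _apply_embedding_bias, using toy tensors in place of LlmRequest objects; the request count, step counts, and bias values are made up for the example and are not taken from the change above.

# Sketch only: mirrors the fast path and the batched path with toy data.
import torch

vocab_size = 8

# Fast path (steps_per_request is None): one logits row per request, so the
# stacked bias rows are added by direct row indexing.
logits = torch.zeros(3, vocab_size)
bias_tensor = torch.stack([torch.full((vocab_size,), 2.0)])  # only request 0 carries a bias
indices = torch.tensor([0])
logits[indices] += bias_tensor

# Batched path: each request owns (1 + num_draft_tokens) consecutive logits
# rows, so each bias is repeated once per step and scattered via a boolean mask.
steps_per_request = [2, 3, 1]                      # rows per request
logits = torch.zeros(sum(steps_per_request), vocab_size)
bias_tensor = torch.stack([torch.full((vocab_size,), 2.0),    # request 0
                           torch.full((vocab_size,), -1.0)])  # request 2 (request 1 unbiased)
repeats = torch.tensor([steps_per_request[0], steps_per_request[2]])
expanded_biases = torch.repeat_interleave(bias_tensor, repeats, dim=0)

mask = torch.zeros(sum(steps_per_request), dtype=torch.bool)
mask[0:2] = True   # rows belonging to request 0
mask[5:6] = True   # row belonging to request 2
logits[mask] += expanded_biases

assert logits[0, 0].item() == 2.0    # biased step of request 0
assert logits[2, 0].item() == 0.0    # request 1 left untouched
assert logits[5, 0].item() == -1.0   # biased step of request 2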