@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Union, List

 import torch

@@ -26,6 +26,7 @@
 from .llm_request import LlmRequest, LlmRequestState
 from .scheduler import ScheduledRequests

+from transformers import PreTrainedTokenizerBase

 @dataclass(kw_only=True)
 class SampleStateTensors:
@@ -224,13 +225,15 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         enable_mixed_sampler: bool
+        tokenizer: PreTrainedTokenizerBase

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
         self.enable_mixed_sampler = args.enable_mixed_sampler
         self.max_tokens = args.max_draft_len + 1
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.num_seq_slots = args.max_num_sequences
+        self.tokenizer = args.tokenizer

         self.NEW_TOKENS_SHAPE = (self.max_tokens, self.num_seq_slots,
                                  self.MAX_BEAM_WIDTH)
@@ -247,22 +250,39 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest):
                 >= self.max_seq_len)

     @staticmethod
-    def _meet_stop_token_criteria(request: LlmRequest):
+    def _meet_stop_token_criteria(
+            request: LlmRequest,
+            tokenizer: PreTrainedTokenizerBase,
+            new_token: Union[int, List[int], torch.Tensor]
+    ):
         if request.py_stop_words_list:
             assert isinstance(
                 request.py_stop_words_list,
                 list), "request.py_stop_words_list should be a list"
+
             stop_words_list, prefix_sum = request.py_stop_words_list
             tokens = request.get_tokens(0)
+            try:
+                new_words = tokenizer.decode(new_token, skip_special_tokens=False, clean_up_tokenization_spaces=False)
+            except Exception:
+                # If decode fails, fall back to token-based matching only
+                new_words = ""
             offset = 0
             for i, offset_end in enumerate(prefix_sum):
                 if i > 0:
                     offset = prefix_sum[i - 1]
                 stop_word = stop_words_list[offset:offset_end]
+                try:
+                    stop_text = tokenizer.decode(stop_word, skip_special_tokens=False, clean_up_tokenization_spaces=False)
+                except Exception:
+                    continue
                 if len(stop_word) > len(tokens):
                     continue
                 if tokens[-len(stop_word):] == stop_word:
                     return True
+                if stop_text in new_words:
+                    return True
+
         return False

     def _handle_stop_criteria(self, request: LlmRequest,
@@ -277,7 +297,7 @@ def _handle_stop_criteria(self, request: LlmRequest,
             request.finish_by(FinishReason.LENGTH, self.BEAM)
             return True

-        if self._meet_stop_token_criteria(request):
+        if self._meet_stop_token_criteria(request, self.tokenizer, new_token):
             request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
             return True

@@ -365,6 +385,7 @@ def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):

     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs: dict[str, torch.Tensor]) -> SampleState:
+
         requests = scheduled_requests.all_requests()
         new_tokens = self.store.new_tokens
         vocab_size = model_outputs["logits"].shape[-1]
@@ -492,6 +513,7 @@ def __init__(
         mapping: Mapping,
         decoding_mode: DecodingMode,
         disable_overlap_scheduler: bool,
+        tokenizer: PreTrainedTokenizerBase
     ):

         vocab_size = model.config.vocab_size
@@ -520,6 +542,8 @@ def __init__(
                                  num_hidden_layers, 0, num_heads,
                                  hidden_size, self.model_datatype)

+        self.tokenizer = tokenizer
+
         self._initialize_store()
         self._instantiate_algorithms()

@@ -625,7 +649,6 @@ def _update_cache_indirection_buffer(self,
     @nvtx_range("sample_async")
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
-
         batch_size = scheduled_requests.batch_size
         beam_width = self.beam_width(scheduled_requests.all_requests())
         if (batch_size > 1 and beam_width > 1
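
Reviewer note: the rationale behind the new text-level check is that a stop string's token IDs depend on the surrounding context, so comparing only the token-ID suffix can miss a stop word the model has in fact emitted. Below is a minimal, self-contained sketch of the two-stage matching this PR performs; `ToyTokenizer` and `meets_stop_criteria` are hypothetical stand-ins for illustration, not the TensorRT-LLM API.

```python
from typing import List, Union


class ToyTokenizer:
    """Stand-in for a real tokenizer's decode(); maps ids to text fragments."""
    _vocab = {1: "Hello", 2: " world", 3: "!", 4: " wor", 5: "ld!"}

    def decode(self, ids: Union[int, List[int]]) -> str:
        if isinstance(ids, int):
            ids = [ids]
        return "".join(self._vocab[i] for i in ids)


def meets_stop_criteria(tokens: List[int], stop_word: List[int],
                        new_token: Union[int, List[int]],
                        tokenizer: ToyTokenizer) -> bool:
    # Stage 1: token-level match, as before this PR -- the generated
    # suffix equals the stop word's token ids exactly.
    if len(stop_word) <= len(tokens) and tokens[-len(stop_word):] == stop_word:
        return True
    # Stage 2: text-level match, as added by this PR -- decode both sides
    # and look for the stop text inside the newly decoded text, catching
    # stop words whose tokenization differs in context.
    try:
        stop_text = tokenizer.decode(stop_word)
        new_text = tokenizer.decode(new_token)
    except Exception:
        return False  # decode failed; the token-level stage already ran
    # Empty-string guard added for safety in this sketch.
    return bool(stop_text) and stop_text in new_text


tok = ToyTokenizer()
# " world!" was registered as ids [2, 3], but the model emitted ids [4, 5]
# (" wor" + "ld!"): the token-level suffix check misses it, while the
# text-level check catches it.
assert meets_stop_criteria([1, 4, 5], [2, 3], [4, 5], tok)
# An exact token-id suffix still matches at stage 1.
assert meets_stop_criteria([1, 2, 3], [2, 3], 3, tok)
assert not meets_stop_criteria([1, 2], [3], 2, tok)
```

Since `request.py_stop_words_list` is fixed for the lifetime of a request, the per-step `tokenizer.decode` of each stop word could presumably be cached on the request if decoding ever shows up in profiles.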