
Commit da386ad

xq25478 <huwen.hu@antgroup.com> authored and committed
imp(torchsampler):support openai stop in text level
Signed-off-by: xq25478 <[email protected]>
1 parent 69e9f6d commit da386ad

9 files changed: +120 −31 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 12 additions & 4 deletions
@@ -34,6 +34,8 @@
                             SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager
 
+from transformers import PreTrainedTokenizerBase
+
 GB = 1 << 30
 
 
@@ -542,7 +544,8 @@ def create_py_executor_instance(
 
 
 def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
-                              *, max_seq_len: int, enable_mixed_sampler: bool):
+                              *, max_seq_len: int, enable_mixed_sampler: bool,
+                              tokenizer: PreTrainedTokenizerBase):
     max_num_sequences = executor_config.max_batch_size * mapping.pp_size
     max_draft_len = (0 if executor_config.speculative_config is None else
                      executor_config.speculative_config.max_draft_len)
@@ -552,18 +555,22 @@ def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
         max_num_sequences=max_num_sequences,
         max_beam_width=executor_config.max_beam_width,
         enable_mixed_sampler=enable_mixed_sampler,
+        tokenizer=tokenizer
     )
 
 
 def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
-                        mapping: Mapping):
+                        mapping: Mapping,
+                        tokenizer: Optional[PreTrainedTokenizerBase]):
    sampler_args = create_torch_sampler_args(
        executor_config,
        mapping,
        max_seq_len=engine.max_seq_len,
-       enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
+       enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
+       tokenizer=tokenizer
+   )
    if mapping.cp_config.get('cp_type') == 'star_attention':
        assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
        return TorchSampler(sampler_args)
@@ -574,7 +581,8 @@ def instantiate_sampler(engine: PyTorchModelEngine,
        decoding_mode = get_decoding_mode(executor_config)
        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
                             mapping, decoding_mode,
-                            pytorch_backend_config.disable_overlap_scheduler)
+                            pytorch_backend_config.disable_overlap_scheduler,
+                            tokenizer)
    if not engine.model.model_config.is_generation:
        # NOTE: choose sampler based on model type
        return EarlyStopSampler()

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 4 additions & 2 deletions
@@ -28,6 +28,7 @@
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
 
+from transformers import PreTrainedTokenizerBase
 
 class _ExecutorCreationStage(enum.Enum):
     SAMPLER = "Sampler"
@@ -185,7 +186,8 @@ def create_py_executor(
         executor_config: ExecutorConfig,
         checkpoint_dir: str = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
+        garbage_collection_gen0_threshold: Optional[int] = None,
+        tokenizer: PreTrainedTokenizerBase = None) -> PyExecutor:
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config
 
@@ -327,7 +329,7 @@ def create_py_executor(
 
     with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
         sampler = instantiate_sampler(model_engine, executor_config,
-                                      pytorch_backend_config, mapping)
+                                      pytorch_backend_config, mapping, tokenizer)
 
     guided_decoder: Optional[GuidedDecoder] = None
     if executor_config.guided_decoding_config is not None:
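
For context, here is a minimal wiring sketch of how a caller might use the new `tokenizer` parameter of `create_py_executor`. It is not runnable on its own: `executor_config` is assumed to be a fully built `ExecutorConfig`, and the helper name `build_executor_with_tokenizer` is purely illustrative, not part of this commit.

```python
# Hypothetical wiring sketch (not part of this commit): load the tokenizer that
# lives next to the checkpoint and hand it to create_py_executor so the sampler
# can decode tokens for text-level stop matching.
from transformers import AutoTokenizer
from tensorrt_llm._torch.pyexecutor.py_executor_creator import create_py_executor


def build_executor_with_tokenizer(executor_config, checkpoint_dir: str):
    # Assumes the checkpoint directory also contains the tokenizer files.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    return create_py_executor(executor_config,
                              checkpoint_dir=checkpoint_dir,
                              tokenizer=tokenizer)
```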

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 27 additions & 4 deletions
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal,Union,List
 
 import torch
 
@@ -26,6 +26,7 @@
 from .llm_request import LlmRequest, LlmRequestState
 from .scheduler import ScheduledRequests
 
+from transformers import PreTrainedTokenizerBase
 
 @dataclass(kw_only=True)
 class SampleStateTensors:
@@ -224,13 +225,15 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         enable_mixed_sampler: bool
+        tokenizer: PreTrainedTokenizerBase
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
         self.enable_mixed_sampler = args.enable_mixed_sampler
         self.max_tokens = args.max_draft_len + 1
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.num_seq_slots = args.max_num_sequences
+        self.tokenizer = args.tokenizer
 
         self.NEW_TOKENS_SHAPE = (self.max_tokens, self.num_seq_slots,
                                  self.MAX_BEAM_WIDTH)
@@ -247,22 +250,39 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest):
                 >= self.max_seq_len)
 
     @staticmethod
-    def _meet_stop_token_criteria(request: LlmRequest):
+    def _meet_stop_token_criteria(
+            request: LlmRequest,
+            tokenizer: PreTrainedTokenizerBase,
+            new_token: Union[int, List[int], torch.Tensor]
+    ):
         if request.py_stop_words_list:
             assert isinstance(
                 request.py_stop_words_list,
                 list), "request.py_stop_words_list should be a list"
+
             stop_words_list, prefix_sum = request.py_stop_words_list
             tokens = request.get_tokens(0)
+            try:
+                new_words = tokenizer.decode(new_token,skip_special_tokens=False,clean_up_tokenization_spaces=False)
+            except Exception:
+                # If decode fails, fall back to token-based matching only
+                new_words = ""
             offset = 0
             for i, offset_end in enumerate(prefix_sum):
                 if i > 0:
                     offset = prefix_sum[i - 1]
                 stop_word = stop_words_list[offset:offset_end]
+                try:
+                    stop_text = tokenizer.decode(stop_word, skip_special_tokens=False, clean_up_tokenization_spaces=False)
+                except Exception:
+                    continue
                 if len(stop_word) > len(tokens):
                     continue
                 if tokens[-len(stop_word):] == stop_word:
                     return True
+                if stop_text in new_words:
+                    return True
+
         return False
 
     def _handle_stop_criteria(self, request: LlmRequest,
@@ -277,7 +297,7 @@ def _handle_stop_criteria(self, request: LlmRequest,
             request.finish_by(FinishReason.LENGTH, self.BEAM)
             return True
 
-        if self._meet_stop_token_criteria(request):
+        if self._meet_stop_token_criteria(request, self.tokenizer, new_token):
             request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
             return True
 
@@ -365,6 +385,7 @@ def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):
 
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs: dict[str, torch.Tensor]) -> SampleState:
+
         requests = scheduled_requests.all_requests()
         new_tokens = self.store.new_tokens
         vocab_size = model_outputs["logits"].shape[-1]
@@ -492,6 +513,7 @@ def __init__(
         mapping: Mapping,
         decoding_mode: DecodingMode,
         disable_overlap_scheduler: bool,
+        tokenizer: PreTrainedTokenizerBase
     ):
 
         vocab_size = model.config.vocab_size
@@ -520,6 +542,8 @@ def __init__(
                                            num_hidden_layers, 0, num_heads,
                                            hidden_size, self.model_datatype)
 
+        self.tokenizer = tokenizer
+
         self._initialize_store()
         self._instantiate_algorithms()
 
@@ -625,7 +649,6 @@ def _update_cache_indirection_buffer(self,
     @nvtx_range("sample_async")
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
-
         batch_size = scheduled_requests.batch_size
         beam_width = self.beam_width(scheduled_requests.all_requests())
         if (batch_size > 1 and beam_width > 1
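
To make the new control flow in `_meet_stop_token_criteria` easier to follow, here is a self-contained sketch of the same idea: keep the original token-ID suffix match, but additionally decode the newly sampled token and the stop words and compare them as text. The function name `meets_stop_criteria` and the choice of the GPT-2 tokenizer are illustrative only, not part of the TensorRT-LLM API.

```python
# Minimal sketch of text-level stop matching, mirroring the logic this commit adds
# to TorchSampler._meet_stop_token_criteria (names here are illustrative only).
from typing import List

from transformers import AutoTokenizer


def meets_stop_criteria(tokens: List[int], stop_words: List[List[int]],
                        new_token: int, tokenizer) -> bool:
    try:
        # Text produced by the token sampled in this step.
        new_text = tokenizer.decode([new_token], skip_special_tokens=False,
                                    clean_up_tokenization_spaces=False)
    except Exception:
        new_text = ""  # decode failed: fall back to token-ID matching only
    for stop_word in stop_words:
        # Original criterion: the generated token IDs end with the stop-word IDs.
        if len(stop_word) <= len(tokens) and tokens[-len(stop_word):] == stop_word:
            return True
        # New criterion: the decoded stop word appears in the newly decoded text.
        try:
            stop_text = tokenizer.decode(stop_word, skip_special_tokens=False,
                                         clean_up_tokenization_spaces=False)
        except Exception:
            continue
        if stop_text in new_text:
            return True
    return False


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")  # any Hugging Face tokenizer works
    generated = tok.encode("The answer is 42. DONE")
    stop_words = [generated[-2:]]  # stop sequence given in token-ID form
    print(meets_stop_criteria(generated, stop_words, generated[-1], tok))  # True
```

Comparing decoded text lets an OpenAI-style string stop fire even when the stop string's own tokenization does not line up with how the model happened to tokenize it during generation.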

tensorrt_llm/executor/executor.py

Lines changed: 11 additions & 8 deletions
@@ -35,6 +35,8 @@
 from .result import GenerationResult, IterationResult
 from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError
 
+from transformers import PreTrainedTokenizerBase
+
 if TYPE_CHECKING:
     from .proxy import GenerationExecutorProxy
     from .worker import GenerationExecutorWorker
@@ -352,6 +354,7 @@ def create(
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -396,8 +399,8 @@ def create(
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+                tokenizer=tokenizer)
 
         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -409,8 +412,8 @@ def create(
             )
             return GenerationExecutorWorker(**worker_kwargs,
                                             is_llm_executor=is_llm_executor,
-                                            garbage_collection_gen0_threshold=
-                                            garbage_collection_gen0_threshold)
+                                            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+                                            tokenizer=tokenizer)
 
         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -423,8 +426,8 @@ def create(
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+                tokenizer=tokenizer)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -436,8 +439,8 @@ def create(
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+                tokenizer=tokenizer)
 
     def wait_first_completed(
             self, futures: List[GenerationResult]

tensorrt_llm/executor/proxy.py

Lines changed: 9 additions & 2 deletions
@@ -28,6 +28,8 @@
                     is_llm_response, print_alive_threads)
 from .worker import GenerationExecutorWorker, worker_main
 
+from transformers import PreTrainedTokenizerBase
+
 __all__ = [
     "GenerationExecutorProxy",
 ]
@@ -46,6 +48,7 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         garbage_collection_gen0_threshold: Optional[int] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
        )
@@ -59,6 +62,7 @@ def __init__(
 
        self.workers_started = False
        self.worker_cls = worker_cls
+       self.tokenizer = tokenizer
 
        mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()
 
@@ -94,7 +98,8 @@ def __init__(
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=False,
            garbage_collection_gen0_threshold=self.
-           garbage_collection_gen0_threshold)
+           garbage_collection_gen0_threshold,
+           tokenizer=tokenizer)
 
        if "log_level" not in worker_kwargs:
            worker_kwargs["log_level"] = logger.level
@@ -410,7 +415,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
-           logprob_params=logprob_params)
+           logprob_params=logprob_params,
+           tokenizer = self.tokenizer
+        )
        self._results[request.id] = result
 
        with nvtx_range_debug("request_queue.put"):

tensorrt_llm/executor/result.py

Lines changed: 39 additions & 6 deletions
@@ -18,6 +18,8 @@
 from ..sampling_params import LogprobParams, SamplingParams
 from .utils import ErrorResponse, has_event_loop, is_llm_response
 
+from transformers import PreTrainedTokenizerBase
+
 if TYPE_CHECKING:
     from .executor import GenerationExecutor
     from .postproc_worker import PostprocParams, PostprocWorker
@@ -139,13 +141,16 @@ def __init__(self,
                  id: int,
                  sampling_params: SamplingParams,
                  background_error_handler: Optional[Callable] = None,
-                 postproc_params: "Optional[PostprocParams]" = None):
+                 postproc_params: "Optional[PostprocParams]" = None,
+                 tokenizer: Optional[PreTrainedTokenizerBase] = None):
         self.id = id
         self.sampling_params = sampling_params
         self.postproc_params = postproc_params
         self.disaggregated_params = None
         self.decoding_iter = 0
         self._done = False
+        self.tokenizer = tokenizer
+
 
         if has_event_loop():
             self.aqueue = AsyncQueue()
@@ -197,6 +202,28 @@ def outputs(self) -> List[CompletionOutput]:
     def context_logits(self) -> Optional[torch.Tensor]:
         return self._context_logits
 
+    def _check_text_stop_criteria(self, output, stop_reason: str, stop_ids: list) -> bool:
+        """Check if the stop text is found in newly generated tokens."""
+        now_token_ids_len = len(output.token_ids)
+        new_generated_token_ids = output.token_ids[output._last_token_ids_len:now_token_ids_len]
+
+        for idx in range(len(new_generated_token_ids)):
+            if self.tokenizer is None:
+                continue
+            new_generated_text = self.tokenizer.decode(
+                new_generated_token_ids[idx],
+                skip_special_tokens=False,
+                clean_up_tokenization_spaces=False
+            )
+            if stop_reason in new_generated_text:
+                output.stop_reason = stop_reason
+                if not self.sampling_params.include_stop_str_in_output:
+                    output.token_ids = output.token_ids[:output._last_token_ids_len + idx]
+                else:
+                    output.token_ids = output.token_ids[:output._last_token_ids_len + idx] + stop_ids
+                return True
+        return False
+
     def _handle_sequence(self,
                          finish_reasons,
                          response_tensors,
@@ -249,11 +276,15 @@ def _handle_sequence(self,
             output.finish_reason = 'stop'
             for stop_reason, stop_ids in self.sampling_params._get_stop_reasons_and_words(
             ):
-                if output.token_ids[-len(stop_ids):] == stop_ids:
-                    output.stop_reason = stop_reason
-                    if not self.sampling_params.include_stop_str_in_output:
-                        output.token_ids = output.token_ids[:-len(stop_ids)]
-                    break
+                if isinstance(stop_reason, str):
+                    if self._check_text_stop_criteria(output, stop_reason, stop_ids):
+                        break
+                else:
+                    if output.token_ids[-len(stop_ids):] == stop_ids:
+                        output.stop_reason = stop_reason
+                        if not self.sampling_params.include_stop_str_in_output:
+                            output.token_ids = output.token_ids[:-len(stop_ids)]
+                        break
         elif finish_reasons[src_idx] == tllm.FinishReason.LENGTH:
             output.finish_reason = 'length'
         elif finish_reasons[src_idx] == tllm.FinishReason.TIMED_OUT:
@@ -412,12 +443,14 @@ def __init__(
         executor: Optional["GenerationExecutor"] = None,
         disaggregated_params: Optional[DisaggregatedParams] = None,
         logprob_params: Optional[LogprobParams] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None
     ) -> None:
         super().__init__(
             generation_request.id,
             generation_request.sampling_params,
             background_error_handler,
             postproc_params=generation_request.postproc_params,
+            tokenizer=tokenizer
         )
         self._generation_request = generation_request
         self._streaming = generation_request.streaming
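
As a companion illustration, here is a standalone sketch of the detokenize-and-truncate behaviour that `_check_text_stop_criteria` implements: walk the tokens produced since the last update, decode each one, and cut the output at the first token whose text contains the stop string, optionally re-appending the stop IDs. The `FakeTokenizer` stub and the function name `truncate_at_stop_text` are illustrative; only a `decode` method is assumed.

```python
# Standalone sketch of the text-level truncation added to GenerationResult handling
# in this commit; names and the stub tokenizer are illustrative only.
from typing import List, Optional, Tuple


def truncate_at_stop_text(token_ids: List[int], last_len: int, stop_text: str,
                          stop_ids: List[int], include_stop: bool,
                          tokenizer) -> Tuple[List[int], Optional[str]]:
    """Return (possibly truncated token_ids, matched stop string or None)."""
    for idx, tok_id in enumerate(token_ids[last_len:]):
        piece = tokenizer.decode([tok_id], skip_special_tokens=False,
                                 clean_up_tokenization_spaces=False)
        if stop_text in piece:
            kept = token_ids[:last_len + idx]
            return (kept + stop_ids if include_stop else kept), stop_text
    return token_ids, None


class FakeTokenizer:
    """Maps token IDs straight to strings; enough to exercise the control flow."""
    vocab = {0: "Hello", 1: " world", 2: "!", 3: "<STOP>"}

    def decode(self, ids, **kwargs):
        ids = ids if isinstance(ids, list) else [ids]
        return "".join(self.vocab[i] for i in ids)


ids, matched = truncate_at_stop_text([0, 1, 2, 3], last_len=2, stop_text="<STOP>",
                                     stop_ids=[3], include_stop=False,
                                     tokenizer=FakeTokenizer())
print(ids, matched)  # [0, 1, 2] <STOP>
```

Decoding token by token keeps the check cheap on the streaming path while still letting `include_stop_str_in_output` decide whether the stop IDs remain in the returned sequence.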
