16 changes: 12 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -34,6 +34,8 @@
SimpleScheduler)
from .seq_slot_manager import SeqSlotManager

from transformers import PreTrainedTokenizerBase

GB = 1 << 30


@@ -542,7 +544,8 @@ def create_py_executor_instance(


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, enable_mixed_sampler: bool):
*, max_seq_len: int, enable_mixed_sampler: bool,
tokenizer: PreTrainedTokenizerBase):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_len = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_len)
@@ -552,18 +555,22 @@ def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
enable_mixed_sampler=enable_mixed_sampler,
tokenizer=tokenizer
)


def instantiate_sampler(engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
mapping: Mapping,
tokenizer: Optional[PreTrainedTokenizerBase]):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
tokenizer=tokenizer
)
if mapping.cp_config.get('cp_type') == 'star_attention':
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
@@ -574,7 +581,8 @@ def instantiate_sampler(engine: PyTorchModelEngine,
decoding_mode = get_decoding_mode(executor_config)
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
pytorch_backend_config.disable_overlap_scheduler,
tokenizer)
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
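For context on where the new `tokenizer` argument ends up, the sketch below builds a `TorchSampler` directly through its `Args` dataclass with a Hugging Face tokenizer. This is a minimal sketch, not part of the PR: it assumes this patch is applied, a CUDA-capable environment (the sampler allocates its token store on GPU), a real checkpoint path in place of the placeholder, and illustrative values for the size limits.

```python
# Minimal sketch, assuming this patch is applied; the checkpoint path is a placeholder.
from transformers import AutoTokenizer

from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler

tokenizer = AutoTokenizer.from_pretrained("/path/to/hf/checkpoint")

sampler = TorchSampler(
    TorchSampler.Args(
        max_seq_len=4096,        # illustrative limits, normally derived from ExecutorConfig
        max_draft_len=0,
        max_num_sequences=8,
        max_beam_width=1,        # TorchSampler asserts beam_width == 1
        enable_mixed_sampler=False,
        tokenizer=tokenizer,     # new field introduced by this change
    )
)
```

In the normal path these values are filled in by `create_torch_sampler_args`, which now simply forwards the tokenizer it receives from `instantiate_sampler`.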
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -28,6 +28,7 @@
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor

from transformers import PreTrainedTokenizerBase

class _ExecutorCreationStage(enum.Enum):
SAMPLER = "Sampler"
@@ -185,7 +186,8 @@ def create_py_executor(
executor_config: ExecutorConfig,
checkpoint_dir: str = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
garbage_collection_gen0_threshold: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> PyExecutor:
_mangle_executor_config(executor_config)
pytorch_backend_config = executor_config.pytorch_backend_config

@@ -327,7 +329,7 @@

with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
sampler = instantiate_sampler(model_engine, executor_config,
pytorch_backend_config, mapping)
pytorch_backend_config, mapping, tokenizer)

guided_decoder: Optional[GuidedDecoder] = None
if executor_config.guided_decoding_config is not None:
101 changes: 65 additions & 36 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod

💡 Verification agent

🧩 Analysis chain

Verify NVIDIA copyright header presence.

According to the coding guidelines, all TensorRT-LLM source files should contain an NVIDIA copyright header with the current year. Please ensure this file includes the required copyright header.


🏁 Script executed:

#!/bin/bash
# Check if the file contains NVIDIA copyright header
head -10 tensorrt_llm/_torch/pyexecutor/sampler.py | grep -i "nvidia\|copyright"

Length of output: 82


Let’s search the entire file for any NVIDIA/copyright lines:


🏁 Script executed:

#!/bin/bash
# Search for any NVIDIA or Copyright mentions in sampler.py
grep -n -E "NVIDIA|Copyright" tensorrt_llm/_torch/pyexecutor/sampler.py || true

Length of output: 268


Add NVIDIA copyright header to sampler.py

This file is missing the required NVIDIA copyright header. Per CODING_GUIDELINES.md, all Python source files must begin with a header that includes the current year and NVIDIA’s copyright.

• File to update:

  • tensorrt_llm/_torch/pyexecutor/sampler.py (at the very top)

Please insert a header similar to:

# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: Apache-2.0

Ensure it matches the exact formatting used in other TensorRT-LLM files.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py at line 1, add the NVIDIA copyright
header at the very top of the file before any imports. Insert the following two
lines exactly as shown: "# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA
CORPORATION & AFFILIATES" and "# SPDX-License-Identifier: Apache-2.0". This will
ensure compliance with the CODING_GUIDELINES.md and match the formatting used in
other TensorRT-LLM files.

from collections.abc import Iterable
from dataclasses import dataclass
from typing import Literal
from typing import Literal, Union, List

import torch

@@ -26,6 +26,7 @@
from .llm_request import LlmRequest, LlmRequestState
from .scheduler import ScheduledRequests

from transformers import PreTrainedTokenizerBase

@dataclass(kw_only=True)
class SampleStateTensors:
@@ -205,6 +206,49 @@ def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor:
return torch.empty(shape, dtype=torch.int, device=device)


def meet_stop_token_criteria(
    request: LlmRequest,
    tokenizer: PreTrainedTokenizerBase,
    new_token: Union[int, List[int], torch.Tensor]
    ):
    if request.py_stop_words_list:
        assert isinstance(
            request.py_stop_words_list,
            list), "request.py_stop_words_list should be a list"

        stop_words_list, prefix_sum = request.py_stop_words_list
        tokens = request.get_tokens(0)
        try:
            new_words = tokenizer.decode(new_token,skip_special_tokens=False,clean_up_tokenization_spaces=False)
        except Exception:
            # If decode fails, fall back to token-based matching only
            new_words = ""
        offset = 0
        for i, offset_end in enumerate(prefix_sum):
            if i > 0:
                offset = prefix_sum[i - 1]
            stop_word = stop_words_list[offset:offset_end]
            try:
                stop_text = tokenizer.decode(stop_word, skip_special_tokens=False, clean_up_tokenization_spaces=False)
            except Exception:
                continue
            if len(stop_word) > len(tokens):
                continue
            if tokens[-len(stop_word):] == stop_word:
                return True
            if stop_text in new_words:
                return True

    return False
Comment on lines +209 to +242

⚠️ Potential issue

Address remaining issues from previous reviews.

The function logic is sound for text-level stop detection, but several previously flagged issues still need to be fixed:

  1. Missing type annotations: Add proper type hints for the return value
  2. Broad exception handling: Replace Exception with specific exceptions
  3. Empty stop text validation: Add check before substring matching

Apply this diff to fix these issues:

 def meet_stop_token_criteria(
     request: LlmRequest,
     tokenizer: PreTrainedTokenizerBase, 
     new_token: Union[int, List[int], torch.Tensor]
-    ):
+    ) -> bool:
     if request.py_stop_words_list:
         assert isinstance(
             request.py_stop_words_list,
             list), "request.py_stop_words_list should be a list"
 
         stop_words_list, prefix_sum = request.py_stop_words_list
         tokens = request.get_tokens(0)
         try: 
-            new_words = tokenizer.decode(new_token,skip_special_tokens=False,clean_up_tokenization_spaces=False)
-        except Exception:
+            new_words = tokenizer.decode(
+                new_token, 
+                skip_special_tokens=False, 
+                clean_up_tokenization_spaces=False
+            )
+        except (ValueError, TypeError, AttributeError):
             # If decode fails, fall back to token-based matching only
             new_words = ""
         offset = 0
         for i, offset_end in enumerate(prefix_sum):
             if i > 0:
                 offset = prefix_sum[i - 1]
             stop_word = stop_words_list[offset:offset_end]
             try:
-                stop_text = tokenizer.decode(stop_word, skip_special_tokens=False, clean_up_tokenization_spaces=False)
-            except Exception:
+                stop_text = tokenizer.decode(
+                    stop_word, 
+                    skip_special_tokens=False, 
+                    clean_up_tokenization_spaces=False
+                )
+            except (ValueError, TypeError, AttributeError):
                 continue
             if len(stop_word) > len(tokens):
                 continue
             if tokens[-len(stop_word):] == stop_word:
                 return True
-            if stop_text in new_words:
+            if stop_text and stop_text in new_words:
                 return True
 
     return False
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py around lines 209 to 242, add a
return type annotation to the meet_stop_token_criteria function to specify it
returns a bool. Replace the broad except Exception clauses with more specific
exceptions related to tokenizer.decode failures, such as UnicodeDecodeError or
TokenizerError if applicable. Before checking if stop_text is in new_words, add
a condition to skip empty stop_text values to avoid false positives or errors
during substring matching.
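The reason the helper decodes both the stop words and the newly generated tokens is that the same stop string can be produced by different token-id sequences, so a pure id-suffix comparison can miss it. The toy below is not TensorRT-LLM code: it uses a stand-in tokenizer and a simplified version of the two checks purely to illustrate that behaviour.

```python
# Stand-in tokenizer: the PR assumes a real transformers PreTrainedTokenizerBase;
# this toy vocabulary exists only to keep the example self-contained and runnable.
class FakeTokenizer:
    _vocab = {1: "Hel", 2: "lo", 3: " wor", 4: "ld", 5: "<stop", 6: ">", 7: "<st", 8: "op>"}

    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        if isinstance(ids, int):
            ids = [ids]
        return "".join(self._vocab[i] for i in ids)


def hits_stop(tokens, new_token_ids, stop_word_ids, tokenizer):
    """Simplified version of the two checks in meet_stop_token_criteria."""
    # 1) Token-level check: the generated sequence ends with the exact stop ids.
    if len(stop_word_ids) <= len(tokens) and tokens[-len(stop_word_ids):] == stop_word_ids:
        return True
    # 2) Text-level check: the decoded stop string appears in the decoded new tokens.
    stop_text = tokenizer.decode(stop_word_ids)
    new_text = tokenizer.decode(new_token_ids)
    return bool(stop_text) and stop_text in new_text


tok = FakeTokenizer()
stop_word_ids = [5, 6]          # "<stop" + ">"   -> "<stop>"
generated = [1, 2, 3, 4, 7, 8]  # "Hello world" + "<st" + "op>" -> also ends in "<stop>"

print(hits_stop(generated, new_token_ids=[7, 8], stop_word_ids=stop_word_ids, tokenizer=tok))
# True: the ids differ ([7, 8] vs [5, 6]) so the token-level check misses,
# but the decoded text "<stop>" is caught by the text-level check.
```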



def meet_max_token_stop_criteria(request: LlmRequest, max_seq_len, beam: int):
    num_tokens = request.get_num_tokens(beam)
    return (num_tokens - request.py_orig_prompt_len
            >= request.py_max_new_tokens) or (num_tokens
                                              >= max_seq_len)


class TorchSampler(Sampler):
BEAM = 0
MAX_BEAM_WIDTH = BEAM + 1
@@ -224,13 +268,15 @@ class Args:
max_num_sequences: int
max_beam_width: int
enable_mixed_sampler: bool
tokenizer: PreTrainedTokenizerBase

def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
self.enable_mixed_sampler = args.enable_mixed_sampler
self.max_tokens = args.max_draft_len + 1
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
self.num_seq_slots = args.max_num_sequences
self.tokenizer = args.tokenizer

self.NEW_TOKENS_SHAPE = (self.max_tokens, self.num_seq_slots,
self.MAX_BEAM_WIDTH)
@@ -240,31 +286,6 @@ def __init__(self, args: Args):
with torch.inference_mode(False):
self.store = self.create_store()

    def _meet_max_token_stop_criteria(self, request: LlmRequest):
        num_tokens = request.get_num_tokens(self.BEAM)
        return (num_tokens - request.py_orig_prompt_len
                >= request.py_max_new_tokens) or (num_tokens
                                                  >= self.max_seq_len)

    @staticmethod
    def _meet_stop_token_criteria(request: LlmRequest):
        if request.py_stop_words_list:
            assert isinstance(
                request.py_stop_words_list,
                list), "request.py_stop_words_list should be a list"
            stop_words_list, prefix_sum = request.py_stop_words_list
            tokens = request.get_tokens(0)
            offset = 0
            for i, offset_end in enumerate(prefix_sum):
                if i > 0:
                    offset = prefix_sum[i - 1]
                stop_word = stop_words_list[offset:offset_end]
                if len(stop_word) > len(tokens):
                    continue
                if tokens[-len(stop_word):] == stop_word:
                    return True
        return False

def _handle_stop_criteria(self, request: LlmRequest,
new_token: int) -> bool:
"""Handle stop criteria and set appropriate finish reasons and state.
@@ -273,11 +294,11 @@ def _handle_stop_criteria(self, request: LlmRequest,
request.finish_by(FinishReason.END_ID, self.BEAM)
return True

if self._meet_max_token_stop_criteria(request):
if meet_max_token_stop_criteria(request, self.max_seq_len, self.BEAM):
request.finish_by(FinishReason.LENGTH, self.BEAM)
return True

if self._meet_stop_token_criteria(request):
if meet_stop_token_criteria(request, self.tokenizer, new_token):
request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
return True

@@ -365,6 +386,7 @@ def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int):

def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor]) -> SampleState:

requests = scheduled_requests.all_requests()
new_tokens = self.store.new_tokens
vocab_size = model_outputs["logits"].shape[-1]
@@ -492,6 +514,7 @@ def __init__(
mapping: Mapping,
decoding_mode: DecodingMode,
disable_overlap_scheduler: bool,
tokenizer: PreTrainedTokenizerBase
):

vocab_size = model.config.vocab_size
@@ -520,6 +543,8 @@ def __init__(
num_hidden_layers, 0, num_heads,
hidden_size, self.model_datatype)

self.tokenizer = tokenizer

self._initialize_store()
self._instantiate_algorithms()

@@ -625,7 +650,6 @@ def _update_cache_indirection_buffer(self,
@nvtx_range("sample_async")
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs) -> SampleStateTRTLLM:

batch_size = scheduled_requests.batch_size
beam_width = self.beam_width(scheduled_requests.all_requests())
if (batch_size > 1 and beam_width > 1
@@ -753,16 +777,17 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
]

# Add new tokens
new_tokens = [
new_tokens_host[r.py_seq_slot] for r in reqs_with_new_tokens
]
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)

# Log probs
for request in reqs_with_new_tokens:
seq_slot = request.py_seq_slot
new_token = new_tokens_host[seq_slot]
if meet_stop_token_criteria(request, self.tokenizer, new_token):
request.state = LlmRequestState.GENERATION_COMPLETE
request.set_finished_reason(FinishReason.STOP_WORDS, 0)

add_new_tokens_to_requests([request], [new_token], 0)

if request.py_return_log_probs:
seq_slot = request.py_seq_slot
seq_len = sequence_lengths_host_data[seq_slot]
begin_log_probs_offset = request.prompt_len
current_token = seq_len - request.prompt_len - 1
@@ -829,6 +854,10 @@ def update_requests_multiple_beams_or_drafting(self,
beam=beam,
step=step)

if meet_stop_token_criteria(request, self.tokenizer, new_token):
request.state = LlmRequestState.GENERATION_COMPLETE
request.set_finished_reason(FinishReason.STOP_WORDS, beam)

if request.py_return_log_probs:
assert state.host.log_probs is not None
# NOTE: Log probs with drafting has not been tested yet.
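Because the stop checks are now module-level functions rather than `TorchSampler` methods, they can be exercised without constructing a sampler. The snippet below is a rough, test-style sketch (assuming a TensorRT-LLM environment with this patch installed) that drives `meet_max_token_stop_criteria` with a stub object exposing only the attributes and method the helper reads; it is not a real `LlmRequest`.

```python
from tensorrt_llm._torch.pyexecutor.sampler import meet_max_token_stop_criteria


class StubRequest:
    """Stand-in for LlmRequest exposing only what the helper touches."""

    def __init__(self, prompt_len: int, max_new_tokens: int, total_tokens: int):
        self.py_orig_prompt_len = prompt_len
        self.py_max_new_tokens = max_new_tokens
        self._total_tokens = total_tokens

    def get_num_tokens(self, beam: int) -> int:
        return self._total_tokens


# 10 prompt tokens + 5 generated tokens reaches max_new_tokens=5 -> stop.
assert meet_max_token_stop_criteria(StubRequest(10, 5, 15), max_seq_len=4096, beam=0)
# Only 3 generated tokens and well under max_seq_len -> keep generating.
assert not meet_max_token_stop_criteria(StubRequest(10, 5, 13), max_seq_len=4096, beam=0)
```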
19 changes: 11 additions & 8 deletions tensorrt_llm/executor/executor.py
@@ -35,6 +35,8 @@
from .result import GenerationResult, IterationResult
from .utils import IntraProcessQueue, ProcessPoolExecutorSession, RequestError

from transformers import PreTrainedTokenizerBase

if TYPE_CHECKING:
from .proxy import GenerationExecutorProxy
from .worker import GenerationExecutorWorker
@@ -352,6 +354,7 @@ def create(
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
@@ -396,8 +399,8 @@
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
tokenizer=tokenizer)

# WAR: For the performance of gathering logits, we use single process worker
# for TP1 to avoid the large overhead of IPC.
@@ -409,8 +412,8 @@
)
return GenerationExecutorWorker(**worker_kwargs,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
tokenizer=tokenizer)

# For single-gpu case:
# Partition the workload to multiple process for streaming performance.
@@ -423,8 +426,8 @@
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
tokenizer=tokenizer)
else:
ctx = multiprocessing.get_context("spawn")
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -436,8 +439,8 @@
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
tokenizer=tokenizer)

def wait_first_completed(
self, futures: List[GenerationResult]
11 changes: 9 additions & 2 deletions tensorrt_llm/executor/proxy.py
@@ -28,6 +28,8 @@
is_llm_response, print_alive_threads)
from .worker import GenerationExecutorWorker, worker_main

from transformers import PreTrainedTokenizerBase

__all__ = [
"GenerationExecutorProxy",
]
@@ -46,6 +48,7 @@ def __init__(
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> None:
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
@@ -59,6 +62,7 @@ def __init__(

self.workers_started = False
self.worker_cls = worker_cls
self.tokenizer = tokenizer

mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()

@@ -94,7 +98,8 @@ def __init__(
postproc_worker_config=postproc_worker_config,
is_llm_executor=False,
garbage_collection_gen0_threshold=self.
garbage_collection_gen0_threshold)
garbage_collection_gen0_threshold,
tokenizer=tokenizer)

if "log_level" not in worker_kwargs:
worker_kwargs["log_level"] = logger.level
@@ -410,7 +415,9 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
background_error_handler=self._handle_background_error,
executor=self,
disaggregated_params=request.disaggregated_params,
logprob_params=logprob_params)
logprob_params=logprob_params,
tokenizer=self.tokenizer
)
self._results[request.id] = result

with nvtx_range_debug("request_queue.put"):