@@ -65,14 +65,9 @@ cuda_graph_config:
max_batch_size: 1024
kv_cache_config:
dtype: fp8
use_torch_sampler: true
EOF
```

> Here `use_torch_sampler: true` is added as a temporary WAR (workaround) for an illegal memory access issue that occurs when using the TRT-LLM native sampler.
>
> TODO: Remove this once the issue is resolved.

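For reference, a rough Python LLM-API equivalent of the YAML options above is sketched below. This is a sketch only, not part of the guide: the model identifier is a placeholder and the config-class imports are assumptions.

```python
# Rough LLM-API sketch of the YAML above. Assumptions: the model id is a
# placeholder, and CudaGraphConfig/KvCacheConfig are importable from
# tensorrt_llm.llmapi in this release.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

llm = LLM(
    model="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",  # placeholder path/id
    cuda_graph_config=CudaGraphConfig(max_batch_size=1024),
    kv_cache_config=KvCacheConfig(dtype="fp8"),
)
```
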
### Launch the TRT-LLM Server

Below is an example command to launch the TRT-LLM server with the Llama-4-Scout-17B-16E-Instruct-FP8 model from within the container. The command is configured specifically for the 1024/1024 input/output sequence length test; each flag is explained in the “Configs and Parameters” section.
8 changes: 4 additions & 4 deletions examples/llm-api/quickstart_advanced.py
@@ -65,9 +65,9 @@ def add_llm_args(parser):
parser.add_argument('--attention_dp_batching_wait_iters',
type=int,
default=0)
parser.add_argument('--use_torch_sampler',
default=False,
action='store_true')
parser.add_argument('--sampler_type',
default="auto",
choices=["auto", "TorchSampler", "TRTLLMSampler"])
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--moe_ep_size', type=int, default=-1)
@@ -230,7 +230,7 @@ def setup_llm(args, **kwargs):
args.use_piecewise_cuda_graph)
if args.use_torch_compile else None,
moe_config=MoeConfig(backend=args.moe_backend),
use_torch_sampler=args.use_torch_sampler,
sampler_type=args.sampler_type,
max_seq_len=args.max_seq_len,
max_batch_size=args.max_batch_size,
max_num_tokens=args.max_num_tokens,
101 changes: 10 additions & 91 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -8,10 +8,9 @@
import tensorrt_llm
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
@@ -589,20 +588,24 @@ def instantiate_sampler(engine: PyTorchModelEngine,
mapping,
max_seq_len=engine.max_seq_len,
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
decoding_mode = get_decoding_mode(executor_config)
if mapping.cp_config.get('cp_type') == CpType.STAR:
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
return get_spec_decoder(sampler_args, engine.spec_config)
if pytorch_backend_config.use_torch_sampler or pytorch_backend_config.enable_mixed_sampler or engine.spec_config is not None:
return TorchSampler(sampler_args)
if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
pytorch_backend_config.sampler_type == SamplerType.auto
and decoding_mode.isBeamSearch()):
logger.debug(f"DecodingMode: {decoding_mode.name}")
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
return TRTLLMSampler(executor_config, engine.model, engine.dtype, mapping,
get_decoding_mode(executor_config),
pytorch_backend_config.disable_overlap_scheduler)
return TorchSampler(sampler_args)


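Condensed, the common-path policy above (no star attention, no speculative decoding, generation model) reduces to the following self-contained sketch; it mirrors the diff's logic but is not the library code itself.

```python
# Self-contained sketch of the sampler-selection policy in the common path.
def pick_sampler(sampler_type: str, beam_search_requested: bool) -> str:
    """sampler_type is one of "auto", "TorchSampler", "TRTLLMSampler"."""
    if sampler_type == "TRTLLMSampler" or (sampler_type == "auto"
                                           and beam_search_requested):
        return "TRTLLMSampler"
    return "TorchSampler"


assert pick_sampler("auto", beam_search_requested=False) == "TorchSampler"
assert pick_sampler("auto", beam_search_requested=True) == "TRTLLMSampler"
assert pick_sampler("TorchSampler", beam_search_requested=True) == "TorchSampler"
```
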
def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
@@ -623,90 +626,6 @@ def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
)
decoding_mode = DecodingMode.TopKTopP()

# Override decoding mode when Medusa is used
if getattr(executor_config.speculative_config, "is_medusa",
False) and not decoding_mode.isMedusa():
logger.warning(
"Model is Medusa, but decoding mode is not Medusa. Overwriting decoding mode to Medusa."
)
decoding_mode = DecodingMode.Medusa()

# Override decoding mode when Medusa is not used
if (not getattr(executor_config.speculative_config, "is_medusa",
False)) and decoding_mode.isMedusa():
logger.warning(
"Model is not Medusa, but decoding mode is Medusa. Overwriting decoding mode."
)
if executor_config.max_beam_width == 1:
decoding_mode = DecodingMode.TopKTopP()
else:
decoding_mode = DecodingMode.BeamSearch()

# Override decoding mode when lookahead decoding is used
if getattr(executor_config.speculative_config, "is_lookahead",
False) and not decoding_mode.isLookahead():
logger.warning(
"Model is Lookahead, but decoding mode is not Lookahead. Overwriting decoding mode to Lookahead."
)
decoding_mode = DecodingMode.Lookahead()

# Override decoding mode when lookahead decoding is not used
if (not getattr(executor_config.speculative_config, "is_lookahead",
False)) and decoding_mode.isLookahead():
logger.warning(
"Model is not built with Lookahead decoding, but decoding mode is Lookahead. Overwriting decoding mode."
)
if executor_config.max_beam_width == 1:
decoding_mode = DecodingMode.TopKTopP()
else:
decoding_mode = DecodingMode.BeamSearch()

# Override decoding mode when 'explicit draft tokens' is used
if getattr(executor_config.speculative_config, "is_explicit_draft_tokens",
False) and not decoding_mode.isExplicitDraftTokens():
logger.warning(
"Model is built with 'explicit draft tokens' decoding, but decoding mode is something else. Overwriting decoding mode."
)
decoding_mode = DecodingMode.ExplicitDraftTokens()

# Override decoding mode when 'explicit draft tokens' is not used
if (not getattr(executor_config.speculative_config,
"is_explicit_draft_tokens",
False)) and decoding_mode.isExplicitDraftTokens():
logger.warning(
"Model is not built with 'explicit draft tokens' decoding, but decoding mode is set to it. Overwriting decoding mode to default."
)
if executor_config.max_beam_width == 1:
decoding_mode = DecodingMode.TopKTopP()
else:
decoding_mode = DecodingMode.BeamSearch()

# Override decoding mode when EAGLE is used
if getattr(executor_config.speculative_config, "is_eagle",
False) and not decoding_mode.isEagle():
logger.warning(
"Model is Eagle, but decoding mode is not Eagle. Overwriting decoding mode to Eagle."
)
decoding_mode = DecodingMode.Eagle()

# Override decoding mode when Eagle is not used
if (not getattr(executor_config.speculative_config, "is_eagle",
False)) and decoding_mode.isEagle():
logger.warning(
"Model is not Eagle, but decoding mode is Eagle. Overwriting decoding mode."
)
if executor_config.max_beam_width == 1:
decoding_mode = DecodingMode.TopKTopP()
else:
decoding_mode = DecodingMode.BeamSearch()

# Override decoding mode when draft tokens are external
if getattr(executor_config.speculative_config, "is_draft_tokens_external",
False):
logger.warning("Overwriting decoding mode to external draft token")
decoding_mode = DecodingMode.ExternalDraftTokens()

logger.debug(f"DecodingMode: {decoding_mode.name}")
return decoding_mode


7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -6,7 +6,7 @@
from tensorrt_llm.bindings.executor import ExecutorConfig

from ...builder import BuildConfig
from ...llmapi.llm_args import LoadFormat
from ...llmapi.llm_args import LoadFormat, SamplerType
from ...logger import logger
from ...mapping import Mapping
from ..model_config import MoeLoadBalancerConfig
@@ -60,9 +60,10 @@ class PyTorchConfig:
If true, will iterate over sampling_params of each request and use the
corresponding sampling strategy, e.g. top-k, top-p, etc.
"""
use_torch_sampler: bool = False
sampler_type: SamplerType = SamplerType.auto
"""
If true, will use the Torch sampler instead of the TRTLLM sampler.
The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
Defaults to auto, which will use TorchSampler unless BeamSearch is requested.
"""

kv_cache_dtype: str = "auto"
17 changes: 12 additions & 5 deletions tensorrt_llm/llmapi/llm_args.py
@@ -1968,6 +1968,13 @@ class LoadFormat(Enum):
DUMMY = 1


class SamplerType(StrEnum):
"""Enum for sampler type options."""
TRTLLMSampler = "TRTLLMSampler"
TorchSampler = "TorchSampler"
auto = "auto"

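The comparisons in `instantiate_sampler` work for both enum members and plain strings because of StrEnum semantics; a quick illustration (assuming Python 3.11's `enum.StrEnum` or an equivalent shim, as used here):

```python
from enum import StrEnum  # Python 3.11+; the codebase may rely on a compatible shim


class SamplerType(StrEnum):
    TRTLLMSampler = "TRTLLMSampler"
    TorchSampler = "TorchSampler"
    auto = "auto"


assert SamplerType.auto == "auto"                               # members equal their strings
assert SamplerType("TorchSampler") is SamplerType.TorchSampler  # strings coerce to members
```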

class TorchCompileConfig(StrictBaseModel):
"""
Configuration for torch.compile.
@@ -2055,11 +2062,11 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc.",
status="beta")

use_torch_sampler: bool = Field(
default=False,
sampler_type: Union[str, SamplerType] = Field(
default=SamplerType.auto,
description=
"If true, will use the Torch sampler instead of the TRTLLM sampler.",
status="beta")
"The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested.",
status="prototype")

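For reference, a minimal usage sketch of the new field through the PyTorch-backend LLM API. The model path is a placeholder, and per the `Union[str, SamplerType]` annotation either the enum member or its string value should be accepted.

```python
# Minimal usage sketch; the model path is a placeholder, not from this PR.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import SamplerType

llm = LLM(model="/path/to/model", sampler_type=SamplerType.TRTLLMSampler)
# or, equivalently, the plain string form:
# llm = LLM(model="/path/to/model", sampler_type="TRTLLMSampler")
```
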
enable_iter_perf_stats: bool = Field(
default=False,
@@ -2344,7 +2351,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
attn_backend=self.attn_backend,
moe_backend=self.moe_config.backend,
enable_mixed_sampler=self.enable_mixed_sampler,
use_torch_sampler=self.use_torch_sampler,
sampler_type=self.sampler_type,
kv_cache_dtype=self.kv_cache_config.dtype,
mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
enable_iter_perf_stats=self.enable_iter_perf_stats,
7 changes: 5 additions & 2 deletions tensorrt_llm/llmapi/utils.py
@@ -3,6 +3,7 @@
import hashlib
import io
import os
import re
import sys
import tempfile
import threading
@@ -508,8 +509,10 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
type_str = str(type_hints[field_name])
type_str = type_str.replace("typing.", "")
# Extract just the class name from full class path
if "<class '" in type_str:
type_str = type_str[8:-2]
for regex in [r"<class '([^']+)'>", r"<enum '([^']+)'>"]:
if (match := re.match(regex, type_str)) is not None:
type_str = match.group(1)
break
else:
type_str = field_type or 'Any'

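The intent of the regex change above, in isolation: pull the bare class or enum name out of a type's repr, falling back to the declared field type otherwise. A standalone sketch (not the library code):

```python
# Standalone sketch of the type-name extraction introduced above.
import re


def extract_type_name(type_str: str, fallback: str = "Any") -> str:
    # "<class 'int'>" -> "int", "<enum 'SamplerType'>" -> "SamplerType"
    for regex in (r"<class '([^']+)'>", r"<enum '([^']+)'>"):
        if (match := re.match(regex, type_str)) is not None:
            return match.group(1)
    return fallback


assert extract_type_name(str(int)) == "int"
assert extract_type_name("<enum 'SamplerType'>") == "SamplerType"
assert extract_type_name("Union[str, int]") == "Any"
```
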
9 changes: 3 additions & 6 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -196,7 +196,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
@skip_pre_hopper
def test_fp8_llm_sampler(self):
model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
with LLM(model_path, use_torch_sampler=True, max_batch_size=256) as llm:
with LLM(model_path, max_batch_size=256) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8

sampling_params = SamplingParams(
@@ -229,7 +229,6 @@ def test_fp8_beam_search(self):
max_beam_width=max_beam_width,
max_batch_size=16,
max_seq_len=1024,
use_torch_sampler=False,
build_config=None)

with llm:
@@ -2011,8 +2010,7 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
cuda_graph, overlap_scheduler):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
use_torch_sampler=True)
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)

with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
tensor_parallel_size=tp_size,
@@ -2034,8 +2032,7 @@ def test_bf16(tp_size, pp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
use_torch_sampler=True)
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)

with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B",
tensor_parallel_size=tp_size,
@@ -12,7 +12,6 @@ context_servers:
backend: "DEFAULT"
urls:
- "localhost:8001"
use_torch_sampler: True
generation_servers:
num_instances: 1
tensor_parallel_size: 1
@@ -11,7 +11,7 @@ context_servers:
max_seq_len: 4096
tensor_parallel_size: 1
pipeline_parallel_size: 1
use_torch_sampler: True
sampler_type: "TRTLLMSampler"
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
@@ -27,7 +27,7 @@ generation_servers:
max_batch_size: 256
max_num_tokens: 4096
max_seq_len: 4096
use_torch_sampler: True
sampler_type: "TRTLLMSampler"
kv_cache_config:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
16 changes: 8 additions & 8 deletions tests/integration/defs/disaggregated/test_disaggregated.py
@@ -55,8 +55,8 @@ def get_test_config(test_desc, example_dir, test_root):
(2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
"mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
"overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
"torch_sampler":
(2, f"{test_configs_root}/disagg_config_torch_sampler.yaml"),
"trtllm_sampler":
(2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
"load_balance":
(4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
"cache_aware_balance":
@@ -211,7 +211,7 @@ def run_disaggregated_test(example_dir,
poll_procs=[workers_proc, server_proc])

# Run the chat completion endpoint test only for TinyLlama
if test_desc == "overlap" or test_desc == "torch_sampler":
if test_desc == "overlap" or test_desc == "trtllm_sampler":
chat_client_cmd = client_cmd + [
'-e', 'chat', '-o', 'output_chat.json'
]
@@ -234,7 +234,7 @@ def run_disaggregated_test(example_dir,
not_expected_strings = ["Berlin Berlin"]

output_files = ['output.json', 'output_streaming.json']
if test_desc == "overlap" or test_desc == "torch_sampler":
if test_desc == "overlap" or test_desc == "trtllm_sampler":
# Disable streaming chat completion for overlap test
# due to bug
output_files.extend(['output_chat.json'])
@@ -488,9 +488,9 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,

@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
disaggregated_example_root,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -501,7 +501,7 @@ def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
os.symlink(src, dst, target_is_directory=True)

run_disaggregated_test(disaggregated_example_root,
"torch_sampler",
"trtllm_sampler",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())

2 changes: 1 addition & 1 deletion tests/integration/test_lists/qa/llm_function_full.txt
@@ -689,7 +689,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_att
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8]
disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_torch_sampler[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-True-Qwen3-8B-FP8]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-False-Qwen3-8B-FP8]
2 changes: 1 addition & 1 deletion tests/integration/test_lists/qa/llm_function_sanity.txt
@@ -122,7 +122,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_torch_sampler[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
4 changes: 0 additions & 4 deletions tests/scripts/perf-sanity/run_benchmark_serve.py
@@ -308,10 +308,6 @@ def generate_extra_llm_api_config(self, test_case: Dict[str, Any]) -> str:
" enable_block_reuse: false",
]

# https://nvbugs/5437106: WAR to avoid illegal memory access in Scout
if "Scout" in test_case['model']:
config_lines.append("use_torch_sampler: true")

# Add moe_config if moe_backend is specified
if test_case['moe_backend']:
config_lines.append("moe_config:")