
Commit 73372ae

dcampora authored and dominicshanshan committed
[TRTLLM-7157][feat] BREAKING CHANGE Introduce sampler_type, detect sampler according to options (NVIDIA#6831)
Signed-off-by: Daniel Campora <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent dd77bea commit 73372ae

27 files changed, +81 -172 lines changed

docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md

Lines changed: 0 additions & 5 deletions

@@ -65,14 +65,9 @@ cuda_graph_config:
   max_batch_size: 1024
 kv_cache_config:
   dtype: fp8
-use_torch_sampler: true
 EOF
 ```

-> Here `use_torch_sampler: true` is added as a temporary WAR to solve illegal memory access issue when using trtllm native sampler.
->
-> TODO: Remove this after the issue is resolved
-
 ### Launch the TRT-LLM Server

 Below is an example command to launch the TRT-LLM server with the Llama-4-Scout-17B-16E-Instruct-FP8 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section.
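With this change the WAR is no longer needed: the sampler is chosen from the new `sampler_type` option, which defaults to `auto` (TorchSampler unless beam search is requested). If a deployment still wants to pin a specific sampler, a minimal sketch through the LLM API could look like the following; the model path is illustrative and not part of this commit:

```python
from tensorrt_llm import LLM

# Hypothetical model path, shown only for illustration.
llm = LLM(
    model="/models/Llama-4-Scout-17B-16E-Instruct-FP8",
    # "auto" (default) resolves to TorchSampler unless beam search is requested;
    # "TorchSampler" or "TRTLLMSampler" force a specific implementation.
    sampler_type="TorchSampler",
)
```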

examples/llm-api/quickstart_advanced.py

Lines changed: 4 additions & 4 deletions

@@ -65,9 +65,9 @@ def add_llm_args(parser):
     parser.add_argument('--attention_dp_batching_wait_iters',
                         type=int,
                         default=0)
-    parser.add_argument('--use_torch_sampler',
-                        default=False,
-                        action='store_true')
+    parser.add_argument('--sampler_type',
+                        default="auto",
+                        choices=["auto", "TorchSampler", "TRTLLMSampler"])
     parser.add_argument('--tp_size', type=int, default=1)
     parser.add_argument('--pp_size', type=int, default=1)
     parser.add_argument('--moe_ep_size', type=int, default=-1)
@@ -230,7 +230,7 @@ def setup_llm(args, **kwargs):
                                args.use_piecewise_cuda_graph)
         if args.use_torch_compile else None,
         moe_config=MoeConfig(backend=args.moe_backend),
-        use_torch_sampler=args.use_torch_sampler,
+        sampler_type=args.sampler_type,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
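For reference, a standalone sketch of how the new flag parses; only the argparse lines above are taken from this commit, the rest of the script is omitted:

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the change above: a string choice instead of a boolean switch.
parser.add_argument('--sampler_type',
                    default="auto",
                    choices=["auto", "TorchSampler", "TRTLLMSampler"])

args = parser.parse_args(["--sampler_type", "TRTLLMSampler"])
print(args.sampler_type)  # -> "TRTLLMSampler"
```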

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 10 additions & 91 deletions

@@ -8,10 +8,9 @@
 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
-from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
@@ -595,20 +594,24 @@ def instantiate_sampler(engine: PyTorchModelEngine,
         mapping,
         max_seq_len=engine.max_seq_len,
         enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
+    decoding_mode = get_decoding_mode(executor_config)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
         return TorchSampler(sampler_args)
     if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
         return get_spec_decoder(sampler_args, engine.spec_config)
-    if pytorch_backend_config.use_torch_sampler or pytorch_backend_config.enable_mixed_sampler or engine.spec_config is not None:
-        return TorchSampler(sampler_args)
+    if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
+            pytorch_backend_config.sampler_type == SamplerType.auto
+            and decoding_mode.isBeamSearch()):
+        logger.debug(f"DecodingMode: {decoding_mode.name}")
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
+                             pytorch_backend_config.disable_overlap_scheduler)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
-    return TRTLLMSampler(executor_config, engine.model, engine.dtype, mapping,
-                         get_decoding_mode(executor_config),
-                         pytorch_backend_config.disable_overlap_scheduler)
+    return TorchSampler(sampler_args)


 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
@@ -629,90 +632,6 @@ def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
         )
         decoding_mode = DecodingMode.TopKTopP()

-    # Override decoding mode when Medusa is used
-    if getattr(executor_config.speculative_config, "is_medusa",
-               False) and not decoding_mode.isMedusa():
-        logger.warning(
-            "Model is Medusa, but decoding mode is not Medusa. Overwriting decoding mode to Medusa."
-        )
-        decoding_mode = DecodingMode.Medusa()
-
-    # Override decoding mode when Medusa is not used
-    if (not getattr(executor_config.speculative_config, "is_medusa",
-                    False)) and decoding_mode.isMedusa():
-        logger.warning(
-            "Model is not Medusa, but decoding mode is Medusa. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when lookahead decoding is used
-    if getattr(executor_config.speculative_config, "is_lookahead",
-               False) and not decoding_mode.isLookahead():
-        logger.warning(
-            "Model is Lookahead, but decoding mode is not Lookahead. Overwriting decoding mode to Lookahead."
-        )
-        decoding_mode = DecodingMode.Lookahead()
-
-    # Override decoding mode when lookahead decoding is not used
-    if (not getattr(executor_config.speculative_config, "is_lookahead",
-                    False)) and decoding_mode.isLookahead():
-        logger.warning(
-            "Model is not built with Lookahead decoding, but decoding mode is Lookahead. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when 'explicit draft tokens' is used
-    if getattr(executor_config.speculative_config, "is_explicit_draft_tokens",
-               False) and not decoding_mode.isExplicitDraftTokens():
-        logger.warning(
-            "Model is built with 'explicit draft tokens' decoding, but decoding mode is something else. Overwriting decoding mode."
-        )
-        decoding_mode = DecodingMode.ExplicitDraftTokens()
-
-    # Override decoding mode when 'explicit draft tokens' is not used
-    if (not getattr(executor_config.speculative_config,
-                    "is_explicit_draft_tokens",
-                    False)) and decoding_mode.isExplicitDraftTokens():
-        logger.warning(
-            "Model is not built with 'explicit draft tokens' decoding, but decoding mode is set to it. Overwriting decoding mode to default."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when EAGLE is used
-    if getattr(executor_config.speculative_config, "is_eagle",
-               False) and not decoding_mode.isEagle():
-        logger.warning(
-            "Model is Eagle, but decoding mode is not Eagle. Overwriting decoding mode to Eagle."
-        )
-        decoding_mode = DecodingMode.Eagle()
-
-    # Override decoding mode when Eagle is not used
-    if (not getattr(executor_config.speculative_config, "is_eagle",
-                    False)) and decoding_mode.isEagle():
-        logger.warning(
-            "Model is not Eagle, but decoding mode is Eagle. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when draft tokens are external
-    if getattr(executor_config.speculative_config, "is_draft_tokens_external",
-               False):
-        logger.warning("Overwriting decoding mode to external draft token")
-        decoding_mode = DecodingMode.ExternalDraftTokens()
-
-    logger.debug(f"DecodingMode: {decoding_mode.name}")
     return decoding_mode

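In short, the new selection logic reduces to the sketch below. This is a simplified, standalone rendition rather than the actual TRT-LLM code; `choose_sampler` and its arguments are stand-ins for `instantiate_sampler`'s final branches:

```python
from enum import Enum


class SamplerType(str, Enum):
    # Mirrors the enum added in llm_args.py.
    TRTLLMSampler = "TRTLLMSampler"
    TorchSampler = "TorchSampler"
    auto = "auto"


def choose_sampler(sampler_type: SamplerType, is_beam_search: bool) -> str:
    """Simplified stand-in for instantiate_sampler's final branches."""
    # TRTLLMSampler is used when requested explicitly, or when "auto"
    # detects a beam-search decoding mode; everything else falls back
    # to TorchSampler.
    if sampler_type == SamplerType.TRTLLMSampler or (
            sampler_type == SamplerType.auto and is_beam_search):
        return "TRTLLMSampler"
    return "TorchSampler"


assert choose_sampler(SamplerType.auto, is_beam_search=False) == "TorchSampler"
assert choose_sampler(SamplerType.auto, is_beam_search=True) == "TRTLLMSampler"
assert choose_sampler(SamplerType.TorchSampler, is_beam_search=True) == "TorchSampler"
```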

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 3 deletions

@@ -6,7 +6,7 @@
 from tensorrt_llm.bindings.executor import ExecutorConfig

 from ...builder import BuildConfig
-from ...llmapi.llm_args import LoadFormat
+from ...llmapi.llm_args import LoadFormat, SamplerType
 from ...logger import logger
 from ...mapping import Mapping
 from ..model_config import MoeLoadBalancerConfig
@@ -60,9 +60,10 @@ class PyTorchConfig:
     If true, will iterate over sampling_params of each request and use the
     corresponding sampling strategy, e.g. top-k, top-p, etc.
     """
-    use_torch_sampler: bool = False
+    sampler_type: SamplerType = SamplerType.auto
     """
-    If true, will use the Torch sampler instead of the TRTLLM sampler.
+    The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
+    Defaults to auto, which will use TorchSampler unless BeamSearch is requested.
     """

     kv_cache_dtype: str = "auto"

tensorrt_llm/llmapi/llm_args.py

Lines changed: 12 additions & 5 deletions

@@ -1968,6 +1968,13 @@ class LoadFormat(Enum):
     DUMMY = 1


+class SamplerType(StrEnum):
+    """Enum for sampler type options."""
+    TRTLLMSampler = "TRTLLMSampler"
+    TorchSampler = "TorchSampler"
+    auto = "auto"
+
+
 class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
@@ -2055,11 +2062,11 @@ class TorchLlmArgs(BaseLlmArgs):
         "If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc.",
         status="beta")

-    use_torch_sampler: bool = Field(
-        default=False,
+    sampler_type: Union[str, SamplerType] = Field(
+        default=SamplerType.auto,
         description=
-        "If true, will use the Torch sampler instead of the TRTLLM sampler.",
-        status="beta")
+        "The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested.",
+        status="prototype")

     enable_iter_perf_stats: bool = Field(
         default=False,
@@ -2344,7 +2351,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
            attn_backend=self.attn_backend,
            moe_backend=self.moe_config.backend,
            enable_mixed_sampler=self.enable_mixed_sampler,
-           use_torch_sampler=self.use_torch_sampler,
+           sampler_type=self.sampler_type,
            kv_cache_dtype=self.kv_cache_config.dtype,
            mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
            enable_iter_perf_stats=self.enable_iter_perf_stats,
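Because `SamplerType` is a `StrEnum`, the string and enum spellings are interchangeable wherever the field is set; a small illustrative check, not taken from this commit:

```python
from tensorrt_llm.llmapi.llm_args import SamplerType

# StrEnum members compare equal to their string values, so YAML/CLI strings
# such as "auto" or "TRTLLMSampler" map directly onto the enum.
assert SamplerType("TRTLLMSampler") is SamplerType.TRTLLMSampler
assert SamplerType.auto == "auto"
```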

tensorrt_llm/llmapi/utils.py

Lines changed: 5 additions & 2 deletions

@@ -3,6 +3,7 @@
 import hashlib
 import io
 import os
+import re
 import sys
 import tempfile
 import threading
@@ -508,8 +509,10 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
             type_str = str(type_hints[field_name])
             type_str = type_str.replace("typing.", "")
             # Extract just the class name from full class path
-            if "<class '" in type_str:
-                type_str = type_str[8:-2]
+            for regex in [r"<class '([^']+)'>", r"<enum '([^']+)'>"]:
+                if (match := re.match(regex, type_str)) is not None:
+                    type_str = match.group(1)
+                    break
         else:
             type_str = field_type or 'Any'
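The effect of the widened pattern matching can be seen on representative type strings; a standalone snippet in which the helper name and inputs are illustrative:

```python
import re


def extract_class_name(type_str: str) -> str:
    # Same pattern pair as the diff: plain classes and enum classes both
    # render with a repr-style wrapper, so try each pattern in turn.
    for regex in [r"<class '([^']+)'>", r"<enum '([^']+)'>"]:
        if (match := re.match(regex, type_str)) is not None:
            return match.group(1)
    return type_str


print(extract_class_name(str(bool)))          # -> bool
print(extract_class_name("<enum 'SamplerType'>"))  # -> SamplerType
```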

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 3 additions & 6 deletions

@@ -196,7 +196,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
     @skip_pre_hopper
     def test_fp8_llm_sampler(self):
         model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
-        with LLM(model_path, use_torch_sampler=True, max_batch_size=256) as llm:
+        with LLM(model_path, max_batch_size=256) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8

             sampling_params = SamplingParams(
@@ -229,7 +229,6 @@ def test_fp8_beam_search(self):
                   max_beam_width=max_beam_width,
                   max_batch_size=16,
                   max_seq_len=1024,
-                  use_torch_sampler=False,
                   build_config=None)

         with llm:
@@ -2042,8 +2041,7 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
                               cuda_graph, overlap_scheduler):
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
-            use_torch_sampler=True)
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None)

         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
                  tensor_parallel_size=tp_size,
@@ -2065,8 +2063,7 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
                   overlap_scheduler):
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
-            use_torch_sampler=True)
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None)

         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B",
                  tensor_parallel_size=tp_size,

tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@ context_servers:
     backend: "DEFAULT"
   urls:
       - "localhost:8001"
-  use_torch_sampler: True
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 1

tests/integration/defs/disaggregated/test_configs/disagg_config_torch_sampler.yaml renamed to tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@ context_servers:
   max_seq_len: 4096
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  use_torch_sampler: True
+  sampler_type: "TRTLLMSampler"
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
@@ -27,7 +27,7 @@ generation_servers:
   max_batch_size: 256
   max_num_tokens: 4096
   max_seq_len: 4096
-  use_torch_sampler: True
+  sampler_type: "TRTLLMSampler"
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 8 additions & 8 deletions

@@ -57,8 +57,8 @@ def get_test_config(test_desc, example_dir, test_root):
         (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
-        "torch_sampler":
-        (2, f"{test_configs_root}/disagg_config_torch_sampler.yaml"),
+        "trtllm_sampler":
+        (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
         "load_balance":
         (4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
         "cache_aware_balance":
@@ -213,7 +213,7 @@ def run_disaggregated_test(example_dir,
                                poll_procs=[workers_proc, server_proc])

         # Run the chat completion endpoint test only for TinyLlama
-        if test_desc == "overlap" or test_desc == "torch_sampler":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
             chat_client_cmd = client_cmd + [
                 '-e', 'chat', '-o', 'output_chat.json'
             ]
@@ -236,7 +236,7 @@ def run_disaggregated_test(example_dir,
         not_expected_strings = ["Berlin Berlin"]

         output_files = ['output.json', 'output_streaming.json']
-        if test_desc == "overlap" or test_desc == "torch_sampler":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
             # Disable streaming chat completion for overlap test
             # due to bug
             output_files.extend(['output_chat.json'])
@@ -513,9 +513,9 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,

 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
-def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
-                                     disaggregated_example_root,
-                                     llama_model_root):
+def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
+                                      disaggregated_example_root,
+                                      llama_model_root):
     src_dst_dict = {
         llama_model_root:
         f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -526,7 +526,7 @@ def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
     os.symlink(src, dst, target_is_directory=True)

     run_disaggregated_test(disaggregated_example_root,
-                           "torch_sampler",
+                           "trtllm_sampler",
                            env=llm_venv._new_env,
                            cwd=llm_venv.get_working_directory())
