diff --git a/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md
index ea6668c04ab..6ec972dc454 100644
--- a/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md
+++ b/docs/source/deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md
@@ -65,14 +65,9 @@ cuda_graph_config:
   max_batch_size: 1024
 kv_cache_config:
   dtype: fp8
-use_torch_sampler: true
 EOF
 ```
 
-> Here `use_torch_sampler: true` is added as a temporary WAR to solve illegal memory access issue when using trtllm native sampler.
->
-> TODO: Remove this after the issue is resolved
-
 ### Launch the TRT-LLM Server
 
 Below is an example command to launch the TRT-LLM server with the Llama-4-Scout-17B-16E-Instruct-FP8 model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section.
diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 6fc6fc28f4e..61240b496de 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -65,9 +65,9 @@ def add_llm_args(parser):
     parser.add_argument('--attention_dp_batching_wait_iters',
                         type=int,
                         default=0)
-    parser.add_argument('--use_torch_sampler',
-                        default=False,
-                        action='store_true')
+    parser.add_argument('--sampler_type',
+                        default="auto",
+                        choices=["auto", "TorchSampler", "TRTLLMSampler"])
     parser.add_argument('--tp_size', type=int, default=1)
     parser.add_argument('--pp_size', type=int, default=1)
     parser.add_argument('--moe_ep_size', type=int, default=-1)
@@ -230,7 +230,7 @@ def setup_llm(args, **kwargs):
             args.use_piecewise_cuda_graph)
         if args.use_torch_compile else None,
         moe_config=MoeConfig(backend=args.moe_backend),
-        use_torch_sampler=args.use_torch_sampler,
+        sampler_type=args.sampler_type,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 8784c1fa409..480aba79bf9 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -8,10 +8,9 @@
 import tensorrt_llm
 import tensorrt_llm.bindings.executor as trtllm
 from tensorrt_llm._torch.model_config import ModelConfig
-from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules)
@@ -589,20 +588,24 @@ def instantiate_sampler(engine: PyTorchModelEngine,
         mapping,
         max_seq_len=engine.max_seq_len,
         enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
+    decoding_mode = get_decoding_mode(executor_config)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
         return TorchSampler(sampler_args)
     if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
         return get_spec_decoder(sampler_args, engine.spec_config)
-    if pytorch_backend_config.use_torch_sampler or pytorch_backend_config.enable_mixed_sampler or engine.spec_config is not None:
-        return TorchSampler(sampler_args)
+    if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
+            pytorch_backend_config.sampler_type == SamplerType.auto
+            and decoding_mode.isBeamSearch()):
+        logger.debug(f"DecodingMode: {decoding_mode.name}")
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
+                             pytorch_backend_config.disable_overlap_scheduler)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
-    return TRTLLMSampler(executor_config, engine.model, engine.dtype, mapping,
-                         get_decoding_mode(executor_config),
-                         pytorch_backend_config.disable_overlap_scheduler)
+    return TorchSampler(sampler_args)
 
 
 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
@@ -623,90 +626,6 @@ def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
         )
         decoding_mode = DecodingMode.TopKTopP()
 
-    # Override decoding mode when Medusa is used
-    if getattr(executor_config.speculative_config, "is_medusa",
-               False) and not decoding_mode.isMedusa():
-        logger.warning(
-            "Model is Medusa, but decoding mode is not Medusa. Overwriting decoding mode to Medusa."
-        )
-        decoding_mode = DecodingMode.Medusa()
-
-    # Override decoding mode when Medusa is not used
-    if (not getattr(executor_config.speculative_config, "is_medusa",
-                    False)) and decoding_mode.isMedusa():
-        logger.warning(
-            "Model is not Medusa, but decoding mode is Medusa. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when lookahead decoding is used
-    if getattr(executor_config.speculative_config, "is_lookahead",
-               False) and not decoding_mode.isLookahead():
-        logger.warning(
-            "Model is Lookahead, but decoding mode is not Lookahead. Overwriting decoding mode to Lookahead."
-        )
-        decoding_mode = DecodingMode.Lookahead()
-
-    # Override decoding mode when lookahead decoding is not used
-    if (not getattr(executor_config.speculative_config, "is_lookahead",
-                    False)) and decoding_mode.isLookahead():
-        logger.warning(
-            "Model is not built with Lookahead decoding, but decoding mode is Lookahead. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when 'explicit draft tokens' is used
-    if getattr(executor_config.speculative_config, "is_explicit_draft_tokens",
-               False) and not decoding_mode.isExplicitDraftTokens():
-        logger.warning(
-            "Model is built with 'explicit draft tokens' decoding, but decoding mode is something else. Overwriting decoding mode."
-        )
-        decoding_mode = DecodingMode.ExplicitDraftTokens()
-
-    # Override decoding mode when 'explicit draft tokens' is not used
-    if (not getattr(executor_config.speculative_config,
-                    "is_explicit_draft_tokens",
-                    False)) and decoding_mode.isExplicitDraftTokens():
-        logger.warning(
-            "Model is not built with 'explicit draft tokens' decoding, but decoding mode is set to it. Overwriting decoding mode to default."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when EAGLE is used
-    if getattr(executor_config.speculative_config, "is_eagle",
-               False) and not decoding_mode.isEagle():
-        logger.warning(
-            "Model is Eagle, but decoding mode is not Eagle. Overwriting decoding mode to Eagle."
-        )
-        decoding_mode = DecodingMode.Eagle()
-
-    # Override decoding mode when Eagle is not used
-    if (not getattr(executor_config.speculative_config, "is_eagle",
-                    False)) and decoding_mode.isEagle():
-        logger.warning(
-            "Model is not Eagle, but decoding mode is Eagle. Overwriting decoding mode."
-        )
-        if executor_config.max_beam_width == 1:
-            decoding_mode = DecodingMode.TopKTopP()
-        else:
-            decoding_mode = DecodingMode.BeamSearch()
-
-    # Override decoding mode when draft tokens are external
-    if getattr(executor_config.speculative_config, "is_draft_tokens_external",
-               False):
-        logger.warning("Overwriting decoding mode to external draft token")
-        decoding_mode = DecodingMode.ExternalDraftTokens()
-
-    logger.debug(f"DecodingMode: {decoding_mode.name}")
     return decoding_mode
 
 
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 226bd0880ee..631f974db26 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -6,7 +6,7 @@
 from tensorrt_llm.bindings.executor import ExecutorConfig
 
 from ...builder import BuildConfig
-from ...llmapi.llm_args import LoadFormat
+from ...llmapi.llm_args import LoadFormat, SamplerType
 from ...logger import logger
 from ...mapping import Mapping
 from ..model_config import MoeLoadBalancerConfig
@@ -60,9 +60,10 @@ class PyTorchConfig:
     If true, will iterate over sampling_params of each request and use the
     corresponding sampling strategy, e.g. top-k, top-p, etc.
     """
-    use_torch_sampler: bool = False
+    sampler_type: SamplerType = SamplerType.auto
     """
-    If true, will use the Torch sampler instead of the TRTLLM sampler.
+    The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
+    Defaults to auto, which will use TorchSampler unless BeamSearch is requested.
     """
 
     kv_cache_dtype: str = "auto"
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 0f377657261..abc41b00356 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1968,6 +1968,13 @@ class LoadFormat(Enum):
     DUMMY = 1
 
 
+class SamplerType(StrEnum):
+    """Enum for sampler type options."""
+    TRTLLMSampler = "TRTLLMSampler"
+    TorchSampler = "TorchSampler"
+    auto = "auto"
+
+
 class TorchCompileConfig(StrictBaseModel):
     """
     Configuration for torch.compile.
@@ -2055,11 +2062,11 @@ class TorchLlmArgs(BaseLlmArgs):
         "If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc.",
         status="beta")
 
-    use_torch_sampler: bool = Field(
-        default=False,
+    sampler_type: Union[str, SamplerType] = Field(
+        default=SamplerType.auto,
         description=
-        "If true, will use the Torch sampler instead of the TRTLLM sampler.",
-        status="beta")
+        "The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. Defaults to auto, which will use TorchSampler unless BeamSearch is requested.",
+        status="prototype")
 
     enable_iter_perf_stats: bool = Field(
         default=False,
@@ -2344,7 +2351,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attn_backend=self.attn_backend,
             moe_backend=self.moe_config.backend,
             enable_mixed_sampler=self.enable_mixed_sampler,
-            use_torch_sampler=self.use_torch_sampler,
+            sampler_type=self.sampler_type,
             kv_cache_dtype=self.kv_cache_config.dtype,
             mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py
index 8b2e516dba2..65000841909 100644
--- a/tensorrt_llm/llmapi/utils.py
+++ b/tensorrt_llm/llmapi/utils.py
@@ -3,6 +3,7 @@
 import hashlib
 import io
 import os
+import re
 import sys
 import tempfile
 import threading
@@ -508,8 +509,10 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
             type_str = str(type_hints[field_name])
             type_str = type_str.replace("typing.", "")
             # Extract just the class name from full class path
-            if "<class '" in type_str:
-                type_str = type_str.split("'")[1]
+            for regex in [r"<class '(.+)'>", r"<enum '(.+)'>"]:
+                if (match := re.match(regex, type_str)) is not None:
+                    type_str = match.group(1)
+                    break
             else:
                 type_str = field_type or 'Any'
 
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 9cb504d5b13..889733057b1 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -196,7 +196,7 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
     @skip_pre_hopper
     def test_fp8_llm_sampler(self):
         model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
-        with LLM(model_path, use_torch_sampler=True, max_batch_size=256) as llm:
+        with LLM(model_path, max_batch_size=256) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
 
             sampling_params = SamplingParams(
@@ -229,7 +229,6 @@ def test_fp8_beam_search(self):
             max_beam_width=max_beam_width,
             max_batch_size=16,
             max_seq_len=1024,
-            use_torch_sampler=False,
             build_config=None)
 
         with llm:
@@ -2011,8 +2010,7 @@ def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
                               cuda_graph, overlap_scheduler):
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
-            use_torch_sampler=True)
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
 
         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
                  tensor_parallel_size=tp_size,
@@ -2034,8 +2032,7 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
                   overlap_scheduler):
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
-            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
-            use_torch_sampler=True)
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
 
         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B",
                  tensor_parallel_size=tp_size,
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
index fad36aac4d8..4e3417c732a 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
@@ -12,7 +12,6 @@ context_servers:
     backend: "DEFAULT"
   urls:
     - "localhost:8001"
-  use_torch_sampler: True
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 1
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_torch_sampler.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
similarity index 93%
rename from tests/integration/defs/disaggregated/test_configs/disagg_config_torch_sampler.yaml
rename to tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
index f4b06f1d14e..287d1103a4f 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_torch_sampler.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
@@ -11,7 +11,7 @@ context_servers:
   max_seq_len: 4096
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
-  use_torch_sampler: True
+  sampler_type: "TRTLLMSampler"
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
@@ -27,7 +27,7 @@ generation_servers:
   max_batch_size: 256
   max_num_tokens: 4096
   max_seq_len: 4096
-  use_torch_sampler: True
+  sampler_type: "TRTLLMSampler"
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py
index 40b112b234b..5f871163d93 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated.py
@@ -55,8 +55,8 @@ def get_test_config(test_desc, example_dir, test_root):
         (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
-        "torch_sampler":
-        (2, f"{test_configs_root}/disagg_config_torch_sampler.yaml"),
+        "trtllm_sampler":
+        (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
         "load_balance":
         (4, f"{test_configs_root}/disagg_config_load_balance.yaml"),
         "cache_aware_balance":
@@ -211,7 +211,7 @@ def run_disaggregated_test(example_dir,
                            poll_procs=[workers_proc, server_proc])
 
         # Run the chat completion endpoint test only for TinyLlama
-        if test_desc == "overlap" or test_desc == "torch_sampler":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
             chat_client_cmd = client_cmd + [
                 '-e', 'chat', '-o', 'output_chat.json'
             ]
@@ -234,7 +234,7 @@ def run_disaggregated_test(example_dir,
         not_expected_strings = ["Berlin Berlin"]
 
         output_files = ['output.json', 'output_streaming.json']
-        if test_desc == "overlap" or test_desc == "torch_sampler":
+        if test_desc == "overlap" or test_desc == "trtllm_sampler":
            # Disable streaming chat completion for overlap test
            # due to bug
            output_files.extend(['output_chat.json'])
@@ -488,9 +488,9 @@ def test_disaggregated_overlap(disaggregated_test_root, llm_venv,
 
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
-def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
-                                     disaggregated_example_root,
-                                     llama_model_root):
+def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,
+                                      disaggregated_example_root,
+                                      llama_model_root):
     src_dst_dict = {
         llama_model_root:
         f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -501,7 +501,7 @@ def test_disaggregated_torch_sampler(disaggregated_test_root, llm_venv,
         os.symlink(src, dst, target_is_directory=True)
 
     run_disaggregated_test(disaggregated_example_root,
-                           "torch_sampler",
+                           "trtllm_sampler",
                            env=llm_venv._new_env,
                            cwd=llm_venv.get_working_directory())
 
diff --git a/tests/integration/test_lists/qa/llm_function_full.txt b/tests/integration/test_lists/qa/llm_function_full.txt
index b523aa1ecaa..ab43ce124f9 100644
--- a/tests/integration/test_lists/qa/llm_function_full.txt
+++ b/tests/integration/test_lists/qa/llm_function_full.txt
@@ -689,7 +689,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_att
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
-disaggregated/test_disaggregated.py::test_disaggregated_torch_sampler[TinyLlama-1.1B-Chat-v1.0]
+disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8]
 disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-True-Qwen3-8B-FP8]
 disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-False-Qwen3-8B-FP8]
diff --git a/tests/integration/test_lists/qa/llm_function_sanity.txt b/tests/integration/test_lists/qa/llm_function_sanity.txt
index 4c92e077d87..aeaa1ba573b 100644
--- a/tests/integration/test_lists/qa/llm_function_sanity.txt
+++ b/tests/integration/test_lists/qa/llm_function_sanity.txt
@@ -122,7 +122,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
-disaggregated/test_disaggregated.py::test_disaggregated_torch_sampler[TinyLlama-1.1B-Chat-v1.0]
+disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
diff --git a/tests/scripts/perf-sanity/run_benchmark_serve.py b/tests/scripts/perf-sanity/run_benchmark_serve.py
index d9c69314879..2d4928ae325 100644
--- a/tests/scripts/perf-sanity/run_benchmark_serve.py
+++ b/tests/scripts/perf-sanity/run_benchmark_serve.py
@@ -308,10 +308,6 @@ def generate_extra_llm_api_config(self, test_case: Dict[str, Any]) -> str:
             "  enable_block_reuse: false",
         ]
 
-        # https://nvbugs/5437106: WAR to avoid illegal memory access in Scout
-        if "Scout" in test_case['model']:
-            config_lines.append("use_torch_sampler: true")
-
         # Add moe_config if moe_backend is specified
         if test_case['moe_backend']:
             config_lines.append("moe_config:")
diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
index 269d43596e7..f5c993a785e 100644
--- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
+++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
@@ -46,6 +46,7 @@ def create_nemotron_h_llm(use_cuda_graph,
             enable_block_reuse=False,
             mamba_ssm_cache_dtype="auto"
             if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype),
+        sampler_type="TRTLLMSampler",
     )
 
 
diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
index fa186ce1e47..5c374d0f2aa 100644
--- a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
+++ b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
@@ -72,7 +72,6 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph,
         pipeline_parallel_size=pp_size,
         enable_attention_dp=enable_attention_dp,
         kv_cache_config=kv_cache_config,
-        use_torch_sampler=True,
         enable_chunked_prefill=True,
     )
     with llm:
diff --git a/tests/unittest/_torch/speculative/test_draft_target.py b/tests/unittest/_torch/speculative/test_draft_target.py
index bbc2f1484e6..05e55b0ea7c 100644
--- a/tests/unittest/_torch/speculative/test_draft_target.py
+++ b/tests/unittest/_torch/speculative/test_draft_target.py
@@ -41,7 +41,6 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
         max_num_tokens=2048,
-        use_torch_sampler=True,
     )
 
     spec_config = DraftTargetDecodingConfig(
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index 56228b4b77a..ffb8e33766a 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -60,7 +60,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         # in this test.
         max_seq_len=8192,
         enable_chunked_prefill=enable_chunked_prefill,
-        use_torch_sampler=True,
     )
     if enable_chunked_prefill:
         # Use a small max_num_tokens so that the chunked prefill path gets exercised.
diff --git a/tests/unittest/_torch/test_overlap_scheduler.py b/tests/unittest/_torch/test_overlap_scheduler.py
index 5ac3044a2eb..8d7406aacc8 100644
--- a/tests/unittest/_torch/test_overlap_scheduler.py
+++ b/tests/unittest/_torch/test_overlap_scheduler.py
@@ -21,10 +21,10 @@ def model_path():
     return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 
 
-def create_llm(model_dir, disable_overlap_scheduler, use_torch_sampler):
+def create_llm(model_dir, disable_overlap_scheduler, sampler_type):
     """Create LLM with specific overlap scheduler setting"""
     pytorch_config = dict(disable_overlap_scheduler=disable_overlap_scheduler,
-                          use_torch_sampler=use_torch_sampler)
+                          sampler_type=sampler_type)
 
     trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
 
@@ -41,16 +41,15 @@ def create_llm(model_dir, disable_overlap_scheduler, use_torch_sampler):
     )
 
 
-@pytest.mark.parametrize("use_torch_sampler", [False, True])
+@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
 @pytest.mark.high_cuda_memory
-def test_overlap_scheduler_consistency(model_path, test_case,
-                                       use_torch_sampler):
+def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
     # Test configuration
     prompts = test_case["prompts"]
     max_new_tokens = test_case["max_new_tokens"]
     temperature = test_case["temperature"]
     top_p = test_case["top_p"]
-    stop_words = test_case["stop_words"] if not use_torch_sampler else None
+    stop_words = test_case["stop_words"]
 
     sampling_config = SamplingParams(max_tokens=max_new_tokens,
                                      stop=stop_words,
@@ -62,7 +61,7 @@ def test_overlap_scheduler_consistency(model_path, test_case,
     # Test with overlap scheduler enabled
     llm = create_llm(model_path,
                      disable_overlap_scheduler=False,
-                     use_torch_sampler=use_torch_sampler)
+                     sampler_type=sampler_type)
     outputs_with_overlap = llm.generate(prompts,
                                         sampling_params=sampling_config,
                                         use_tqdm=True)
@@ -74,7 +73,7 @@ def test_overlap_scheduler_consistency(model_path, test_case,
     # Test with overlap scheduler disabled
     llm = create_llm(model_path,
                      disable_overlap_scheduler=True,
-                     use_torch_sampler=use_torch_sampler)
+                     sampler_type=sampler_type)
     outputs_without_overlap = llm.generate(prompts,
                                            sampling_params=sampling_config,
                                            use_tqdm=True)
diff --git a/tests/unittest/_torch/test_return_logits.py b/tests/unittest/_torch/test_return_logits.py
index 9010834a6f9..0d6a5e28ca6 100644
--- a/tests/unittest/_torch/test_return_logits.py
+++ b/tests/unittest/_torch/test_return_logits.py
@@ -16,10 +16,10 @@
 @pytest.mark.parametrize("return_log_probs", [False, True])
 @pytest.mark.parametrize("gather_generation_logits", [False, True])
 @pytest.mark.parametrize("gather_context_logits", [False, True])
-@pytest.mark.parametrize("use_torch_sampler", [False, True])
+@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"])
 @pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
 def test_generate_with_return_logits(disable_overlap_scheduler: bool,
-                                     use_torch_sampler: bool,
+                                     sampler_type: str,
                                      gather_context_logits: bool,
                                      gather_generation_logits: bool,
                                      return_log_probs: bool):
@@ -27,7 +27,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if use_torch_sampler and gather_context_logits:
+    if sampler_type == "TorchSampler" and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()
@@ -41,7 +41,7 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
         gather_generation_logits=gather_generation_logits,
         max_batch_size=
         128,  # reduce buffer sizes, specially for generation logits
-        use_torch_sampler=use_torch_sampler,
+        sampler_type=sampler_type,
         disable_overlap_scheduler=disable_overlap_scheduler,
     )
 
@@ -83,10 +83,10 @@ def test_generate_with_return_logits(disable_overlap_scheduler: bool,
 @pytest.mark.parametrize("return_log_probs", [False, True])
 @pytest.mark.parametrize("gather_generation_logits", [False, True])
 @pytest.mark.parametrize("gather_context_logits", [False, True])
-@pytest.mark.parametrize("use_torch_sampler", [False, True])
+@pytest.mark.parametrize("sampler_type", ["TRTLLMSampler", "TorchSampler"])
 @pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
 def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
-                                           use_torch_sampler: bool,
+                                           sampler_type: str,
                                            gather_context_logits: bool,
                                            gather_generation_logits: bool,
                                            return_log_probs: bool):
@@ -94,7 +94,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
             or return_log_probs):  # prune space
         pytest.skip("Nothing to test")
 
-    if use_torch_sampler and gather_context_logits:
+    if sampler_type == "TorchSampler" and gather_context_logits:
         pytest.skip("TorchSampler does not support gather_context_logits")
 
     build_config = BuildConfig()
@@ -108,7 +108,7 @@ def test_generate_async_with_return_logits(disable_overlap_scheduler: bool,
         gather_generation_logits=gather_generation_logits,
         max_batch_size=
         128,  # reduce buffer sizes, specially for generation logits
-        use_torch_sampler=use_torch_sampler,
+        sampler_type=sampler_type,
         disable_overlap_scheduler=disable_overlap_scheduler,
     )
     sampling_params = SamplingParams(
diff --git a/tests/unittest/api_stability/api_stability_core.py b/tests/unittest/api_stability/api_stability_core.py
index 2278fad2011..61650d59097 100644
--- a/tests/unittest/api_stability/api_stability_core.py
+++ b/tests/unittest/api_stability/api_stability_core.py
@@ -27,6 +27,7 @@
 from tensorrt_llm.llmapi import (CalibConfig, CompletionOutput,
                                  GuidedDecodingParams, QuantConfig,
                                  RequestOutput, SamplingParams)
+from tensorrt_llm.llmapi.llm_args import SamplerType
 from tensorrt_llm.llmapi.llm_utils import LlmArgs
 from tensorrt_llm.logger import Singleton
 
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 5a846dd7869..86f740c3844 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -111,9 +111,9 @@ methods:
         annotation: bool
         default: False
         status: beta
-      use_torch_sampler:
-        annotation: bool
-        default: False
+      sampler_type:
+        annotation: Union[str, tensorrt_llm.llmapi.llm_args.SamplerType]
+        default: auto
         status: beta
       enable_iter_perf_stats:
         annotation: bool
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat.py b/tests/unittest/llmapi/apps/_test_openai_chat.py
index e59a0fae9fa..6e58b094783 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat.py
@@ -533,10 +533,10 @@ def test_stop_reason(client: openai.OpenAI, model_name: str, backend: str):
     'server_with_custom_sampler',
     [
         {
-            'use_torch_sampler': True
+            'sampler_type': "TorchSampler"
         },  # torch_sampler
         {
-            'use_torch_sampler': False
+            'sampler_type': "TRTLLMSampler"
         },  # trtllm_sampler
     ],
     indirect=True,
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
index ab3c5ac58c7..d92ca061672 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
@@ -32,7 +32,6 @@ def temp_extra_llm_api_options_file(request):
         "build_config": {
             "max_num_tokens": 16384,
         },
-        "use_torch_sampler": True,
     }
 
     with open(temp_file_path, 'w') as f:
diff --git a/tests/unittest/llmapi/apps/_test_openai_completions.py b/tests/unittest/llmapi/apps/_test_openai_completions.py
index 4762f219960..3e1c96cff3c 100644
--- a/tests/unittest/llmapi/apps/_test_openai_completions.py
+++ b/tests/unittest/llmapi/apps/_test_openai_completions.py
@@ -395,10 +395,10 @@ async def test_completion_streaming(async_client: openai.AsyncOpenAI,
     'server_with_custom_sampler',
     [
         {
-            'use_torch_sampler': True
+            'sampler_type': "TorchSampler"
         },  # torch_sampler
         {
-            'use_torch_sampler': False
+            'sampler_type': "TRTLLMSampler"
         },  # trtllm_sampler
     ],
     indirect=True,
diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py
index a86301a6748..5b28e12675c 100644
--- a/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py
+++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_multimodal_example.py
@@ -32,7 +32,6 @@ def temp_extra_llm_api_options_file(request):
         "build_config": {
             "max_num_tokens": 16384,
         },
-        "use_torch_sampler": True,
     }
 
     with open(temp_file_path, 'w') as f:
diff --git a/tests/unittest/llmapi/apps/utils.py b/tests/unittest/llmapi/apps/utils.py
index ae78a3180bb..073760d51f3 100644
--- a/tests/unittest/llmapi/apps/utils.py
+++ b/tests/unittest/llmapi/apps/utils.py
@@ -151,8 +151,7 @@ def make_server_with_custom_sampler_fixture(api_type: str) -> Callable:
     def server_with_custom_sampler(model_name: str, request: Any, backend: str,
                                    tmp_path: Path) -> RemoteOpenAIServer:
         '''Fixture to launch a server (pytorch backend only) with a custom sampler
         configuration.'''
-        use_torch_sampler = getattr(request, 'param',
-                                    {}).get('use_torch_sampler', True)
+        sampler_type = getattr(request, 'param', {}).get('sampler_type', "auto")
         if backend != 'pytorch':
             pytest.skip(
                 f"Server with custom sampler is only supported for pytorch backend, skipping for {backend}"
             )
@@ -162,7 +161,7 @@ def server_with_custom_sampler(model_name: str, request: Any, backend: str,
         temp_file_path = tmp_path / f'test_sampler_config_{request.node.name}.yaml'
         extra_llm_api_options_dict = {
             'enable_chunked_prefill': True,
-            'use_torch_sampler': use_torch_sampler
+            'sampler_type': sampler_type
         }
         with temp_file_path.open('w') as f:
             yaml.dump(extra_llm_api_options_dict, f)
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index be0de40eb65..6b78c46bd73 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -255,14 +255,11 @@ def test_embedding_bias_with_torch_sampler_strategies(enable_mixed_sampler,
 
     sampling_params = SamplingParams(**sampling_kwargs)
 
-    llm_test_harness(
-        llama_model_path,
-        prompts,
-        ["Z Z Z Z Z Z"],
-        sampling_params=sampling_params,
-        backend="pytorch",
-        use_torch_sampler=True,  # Use TorchSampler to test all 3 paths
-        enable_mixed_sampler=enable_mixed_sampler)
+    llm_test_harness(llama_model_path,
+                     prompts, ["Z Z Z Z Z Z"],
+                     sampling_params=sampling_params,
+                     backend="pytorch",
+                     enable_mixed_sampler=enable_mixed_sampler)
 
 
 def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
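
The sketch below is not part of the diff; it is a minimal usage example of the `sampler_type` option that replaces `use_torch_sampler` in this change. The accepted values mirror the new `SamplerType` enum (`"auto"`, `"TorchSampler"`, `"TRTLLMSampler"`); the model path and the top-level `LLM`/`SamplingParams` imports are assumptions, not part of this patch.

```python
# Minimal sketch, assuming the PyTorch backend and a locally available model.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # hypothetical model path
    sampler_type="TRTLLMSampler",  # "auto" (default) uses TorchSampler unless beam search is requested
)
outputs = llm.generate(["Hello, my name is"],
                       sampling_params=SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The same key can be set in an extra LLM API options YAML for `trtllm-serve`, as the updated disaggregated test configs above do with `sampler_type: "TRTLLMSampler"`.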