diff --git a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h index 048a84ecca3..1916a915e33 100644 --- a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h +++ b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h @@ -47,7 +47,7 @@ class LogitsPostProcessor : Algorithm bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor, runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream, - std::optional logitsPostProcessorBatched = std::nullopt) const; + std::optional const& logitsPostProcessorBatched = std::nullopt) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp index dd34de0ef9a..dbb90da326a 100644 --- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp +++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp @@ -34,7 +34,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor, tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream, - std::optional logitsPostProcessorBatched) const + std::optional const& logitsPostProcessorBatched) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(LogitsPostProcessor); diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst index b59a588cac1..8b7d25e7359 100644 --- a/docs/source/commands/trtllm-serve/trtllm-serve.rst +++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst @@ -201,56 +201,60 @@ Metrics Endpoint .. note:: - This endpoint is beta maturity. + The metrics endpoint for the default PyTorch backend are in beta and are not as comprehensive as those for the TensorRT backend. - The statistics for the PyTorch backend are beta and not as comprehensive as those for the TensorRT backend. + Some fields, such as CPU memory usage, are not yet available for the PyTorch backend. - Some fields, such as CPU memory usage, are not available for the PyTorch backend. + Enabling ``enable_iter_perf_stats`` in the PyTorch backend can slightly impact performance, depending on the serving configuration. - Enabling ``enable_iter_perf_stats`` in the PyTorch backend can impact performance slightly, depending on the serving configuration. +The ``/metrics`` endpoint provides runtime iteration statistics such as GPU memory usage and KV cache details. -The ``/metrics`` endpoint provides runtime-iteration statistics such as GPU memory use and inflight-batching details. -For the TensorRT backend, these statistics are enabled by default. -However, for the PyTorch backend, you must explicitly enable iteration statistics logging by setting the `enable_iter_perf_stats` field in a YAML configuration file as shown in the following example: +For the default PyTorch backend, iteration statistics logging is enabled by setting the ``enable_iter_perf_stats`` field in a YAML file: .. code-block:: yaml - # extra-llm-api-config.yml - pytorch_backend_config: - enable_iter_perf_stats: true + # extra_llm_config.yaml + enable_iter_perf_stats: true -Then start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file as shown in the following example: +Start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file: .. code-block:: bash - trtllm-serve \ - --extra_llm_api_options \ - [--tp_size --pp_size --ep_size --host --port ] + trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --extra_llm_api_options extra_llm_config.yaml -After at least one inference request is sent to the server, you can fetch the runtime-iteration statistics by polling the `/metrics` endpoint: +After sending at least one inference request to the server, you can fetch runtime iteration statistics by polling the ``/metrics`` endpoint. +Since the statistics are stored in an internal queue and removed once retrieved, it's recommended to poll the endpoint shortly after each request and store the results if needed. .. code-block:: bash - curl -X GET http://:/metrics + curl -X GET http://localhost:8000/metrics -*Example Output* +Example output: .. code-block:: json - [ - { - "gpuMemUsage": 56401920000, - "inflightBatchingStats": { + [ + { + "gpuMemUsage": 76665782272, + "iter": 154, + "iterLatencyMS": 7.00688362121582, + "kvCacheStats": { + "allocNewBlocks": 3126, + "allocTotalBlocks": 3126, + "cacheHitRate": 0.00128, + "freeNumBlocks": 101253, + "maxNumBlocks": 101256, + "missedBlocks": 3121, + "reusedBlocks": 4, + "tokensPerBlock": 32, + "usedNumBlocks": 3 + }, + "numActiveRequests": 1 ... - }, - "iter": 1, - "iterLatencyMS": 16.505143404006958, - "kvCacheStats": { - ... - }, - "newActiveRequestsQueueLatencyMS": 0.0007503032684326172 - } -] + } + ] + + Syntax ------ diff --git a/docs/source/legacy/tensorrt_quickstart.md b/docs/source/legacy/tensorrt_quickstart.md index df62aa38d73..e74a0f5e9e2 100644 --- a/docs/source/legacy/tensorrt_quickstart.md +++ b/docs/source/legacy/tensorrt_quickstart.md @@ -1,7 +1,7 @@ # LLM API with TensorRT Engine A simple inference example with TinyLlama using the LLM API: -```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py +```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py :language: python :linenos: ``` diff --git a/examples/llm-api/_tensorrt_engine/quickstart_example.py b/examples/llm-api/_tensorrt_engine/quickstart_example.py index 400a241c0e9..a6ba9ec5598 100644 --- a/examples/llm-api/_tensorrt_engine/quickstart_example.py +++ b/examples/llm-api/_tensorrt_engine/quickstart_example.py @@ -1,11 +1,17 @@ -from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm import BuildConfig, SamplingParams +from tensorrt_llm._tensorrt_engine import LLM # NOTE the change def main(): + build_config = BuildConfig() + build_config.max_batch_size = 256 + build_config.max_num_tokens = 1024 + # Model could accept HF model name, a path to local HF model, # or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF. - llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + build_config=build_config) # Sample prompts. prompts = [ diff --git a/examples/llm-api/llm_mgmn_trtllm_bench.sh b/examples/llm-api/llm_mgmn_trtllm_bench.sh index de7ee73d536..5169c00ad38 100644 --- a/examples/llm-api/llm_mgmn_trtllm_bench.sh +++ b/examples/llm-api/llm_mgmn_trtllm_bench.sh @@ -76,6 +76,7 @@ srun -l \ # This is optional cat > /tmp/pytorch_extra_args.txt << EOF +cuda_graph_config: null print_iter_log: true enable_attention_dp: false EOF diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index b8bf3304883..74adc69c02b 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -170,7 +170,8 @@ def __post_init__(self) -> None: def create_cuda_graph_metadata(self, max_batch_size: int, sub_cross_metadata: bool = False, - max_draft_tokens: int = 0) -> Self: + max_draft_tokens: int = 0, + buffers=None) -> Self: metadata = super().create_cuda_graph_metadata(max_batch_size, sub_cross_metadata, max_draft_tokens) diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index aa081e82dd4..6a035ad477c 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -140,6 +140,7 @@ class AttentionMetadata: # This buffer is currently only used for TrtllmAttentionMetadata. cache_indirection: Optional[torch.Tensor] = None + cuda_graph_buffers: dict[str, list[torch.Tensor]] = None _saved_tensors: Dict[str, torch.Tensor] = field(init=False, default_factory=dict) @@ -288,7 +289,8 @@ def prepare(self): def create_cuda_graph_metadata(self, max_batch_size: int, sub_cross_metadata: bool = False, - max_draft_tokens: int = 0) -> Self: + max_draft_tokens: int = 0, + buffers=None) -> Self: """ Creates metadata for CUDA graph execution. CUDA graphs require to use pre-allocated buffers for all tensors in fields. @@ -300,6 +302,7 @@ def create_cuda_graph_metadata(self, cuda_graph_metadata = copy.copy(self) cuda_graph_metadata.is_cuda_graph = True + cuda_graph_metadata.cuda_graph_buffers = buffers if self.has_cross_sub_metadata: cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata( max_batch_size, True) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index a95519f22ce..cdca67a7b95 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -600,13 +600,65 @@ def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]: def __post_init__(self) -> None: super().__post_init__() + self._post_init_with_buffers(self.cuda_graph_buffers) + + def _post_init_with_buffers(self, buffers) -> None: + # Set a default value, as max_num_sequences is not always set. if self.max_num_sequences is None: self.max_num_sequences = self.max_num_requests - self.prompt_lens_cuda = torch.empty( + def get_empty(tensor_shape: list[int], dtype: torch.dtype, + cache_name: str) -> torch.Tensor: + """ + Finds a compatible, reusable buffer from a cache or creates a new one. + + This function searches for a pre-allocated tensor (buffer) that can be + reused for an operation involving a tensor with the shape of `tensor_shape`. + + The compatibility rules are: The buffer's total elements must be >= tensor_shape's. + + If a compatible buffer is found, it's returned immediately. Otherwise, a new + buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'. + + Args: + tensor_shape: The required shape. + dtype: The required dtype. + cache_name: The key for the specific list of buffers to search in. + + Returns: + An existing compatible buffer or a newly created one. + """ + if buffers is not None: + # Safely get the list of candidates. Defaults to an empty list if key is missing. + candidate_buffers = buffers.get(cache_name, []) + numel_like = math.prod(tensor_shape) + + for buffer in candidate_buffers: + numel_buffer = buffer.numel() + + # buffer just needs to be large enough. + if numel_buffer >= numel_like: + return buffer[0:numel_like].view( + tensor_shape) # Found a fit, return immediately. + + # If we get here, no suitable buffer was found in the cache. Create a new one. + new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype) + if buffers is not None: + buffers.setdefault(cache_name, []).append(new_buffer) + return new_buffer + + def get_empty_like(like_tensor: torch.Tensor, + cache_name: str) -> torch.Tensor: + return get_empty( + like_tensor.shape, + cache_name=cache_name, + dtype=like_tensor.dtype, + ) + + self.prompt_lens_cuda = get_empty( (self.max_num_sequences, ), - device='cuda', + cache_name="prompt_lens_cuda", dtype=torch.int, ) self.prompt_lens_cpu = torch.empty_like( @@ -614,7 +666,10 @@ def __post_init__(self) -> None: device='cpu', pin_memory=True, ) - self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda) + self.kv_lens_cuda = get_empty_like( + self.prompt_lens_cuda, + cache_name="kv_lens_cuda", + ) self.kv_lens = torch.empty_like(self.kv_lens_cuda, device='cpu', pin_memory=True) @@ -629,13 +684,13 @@ def __post_init__(self) -> None: dtype=torch.int8, ) if self.kv_cache_manager is not None: - self.kv_cache_block_offsets = torch.empty( + self.kv_cache_block_offsets = get_empty( [ self.kv_cache_manager.num_pools, self.max_num_sequences, 2, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="kv_cache_block_offsets", dtype=torch.int32, - device='cuda', ) self.host_kv_cache_block_offsets = torch.empty_like( self.kv_cache_block_offsets, @@ -645,27 +700,27 @@ def __post_init__(self) -> None: self.block_ids_per_seq = None self.kv_block_ids_per_seq = None if self.enable_flash_mla: - self.block_ids_per_seq = torch.zeros( + self.block_ids_per_seq = get_empty( [ self.kv_cache_manager.max_batch_size, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="block_ids_per_seq", dtype=torch.int32, - device='cuda', ) - self.kv_block_ids_per_seq = torch.zeros( + self.kv_block_ids_per_seq = get_empty( [ self.kv_cache_manager.max_batch_size, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="kv_block_ids_per_seq", dtype=torch.int32, - device='cuda', ) if self.enable_context_mla_with_cached_kv: # for kv cache reuse/chunked context in MLA - self.ctx_cached_token_indptr = torch.zeros( + self.ctx_cached_token_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_cached_token_indptr", dtype=torch.int64, ) self.host_ctx_cached_token_indptr = torch.zeros_like( @@ -673,9 +728,9 @@ def __post_init__(self) -> None: device='cpu', pin_memory=True, ) - self.ctx_uncached_token_indptr = torch.zeros( + self.ctx_uncached_token_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_uncached_token_indptr", dtype=torch.int64, ) self.host_ctx_uncached_token_indptr = torch.zeros_like( @@ -684,9 +739,9 @@ def __post_init__(self) -> None: pin_memory=True, ) # context full seqlens include cached tokens and uncached tokens - self.ctx_kv_indptr = torch.zeros( + self.ctx_kv_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_kv_indptr", dtype=torch.int64, ) self.host_ctx_kv_indptr = torch.zeros_like( @@ -1165,7 +1220,7 @@ def forward( host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers, host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping, block_ids_per_seq=metadata.block_ids_per_seq, - workspace=metadata.workspace, + workspace=None, cache_indirection=metadata.cache_indirection, kv_scale_orig_quant=self.kv_scale_orig_quant, kv_scale_quant_orig=self.kv_scale_quant_orig, diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index aa1b250b3a1..846386bb1f6 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -371,7 +371,7 @@ def choose_one( if not is_cache_hit: logger.warning_once( f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}", - key=(custom_op)) + key=custom_op) return (best_runner, best_tactic) diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py index 32c37d5339c..6139131e478 100644 --- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py +++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py @@ -210,15 +210,9 @@ def __call__(self, *args): runtime_input_addresses = [ i.data_ptr() for i in args if isinstance(i, torch.Tensor) ] - runtime_output_addresses = [ - i.data_ptr() for i in output if isinstance(i, torch.Tensor) - ] assert (entry.input_addresses == runtime_input_addresses ), f"{entry.input_addresses} vs\n {runtime_input_addresses}" - assert ( - entry.output_addresses == runtime_output_addresses - ), f"{entry.output_addresses} vs\n {runtime_output_addresses}" entry.cuda_graph.replay() diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index a4ce0092a0b..f77d309805e 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -494,7 +494,8 @@ def get_bindings_model_config(self, architectures = self.pretrained_config.architectures if len(architectures ) == 1 and architectures[0] == "DeciLMForCausalLM": - mlp_hidden_size = self._infer_nemotron_ffn_mult() + mlp_hidden_size = self._infer_nemotron_ffn_mult( + ) // self.mapping.tp_size else: raise ValueError( f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet" diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 7158c23f527..7e84fbde5cf 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -302,7 +302,8 @@ def pack_image_features(self, logger.warning_once( "Image feature shape does not line up with the provided patch size. " "You may be using the `default` vision_feature_select_strategy with a" - " visual encoder that does not have CLS.") + " visual encoder that does not have CLS.", + key="llava_next_vision_model_pack_image_features") image_feature = image_feature.view(num_patch_height, num_patch_width, height, diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py index cbd2ebb9836..8087ef30dff 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py @@ -48,6 +48,9 @@ def __init__( rope=RopeParams.from_config(config), ) + # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712) + # TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains + super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -85,6 +88,7 @@ def __init__( dtype=config.torch_dtype, config=model_config, ) + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index d353cf21a99..c488e2cd3f9 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -216,6 +216,7 @@ def __init__( skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) + self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index a74d8f2e738..4e18ae8c245 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -3,6 +3,8 @@ import torch from torch import nn +from tensorrt_llm._utils import get_sm_version + from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor, next_positive_power_of_2 from .interface import MoE, MoEWeightLoadingMode @@ -78,6 +80,11 @@ def __init__( swiglu_limit=swiglu_limit, ) + sm_version = get_sm_version() + if sm_version >= 120: + raise NotImplementedError( + "TRTLLMGenFusedMoE does not support SM120 and above.") + assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." self.num_slots = self.num_experts diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index f177c418850..cf381ea2c27 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from torch import nn +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from ..distributed import AllReduceParams @@ -29,6 +30,7 @@ def __init__(self, reduce_output: bool = True, layer_idx: Optional[int] = None, use_cute_dsl_blockscaling_mm: bool = False): + super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -98,12 +100,21 @@ def __init__(self, [LoraModuleType.MLP_GATE_UP], [2 * self.intermediate_size // mapping.tp_size]) - def _apply_activation(self, x): + def _apply_activation(self, x, *, has_lora: bool = False): if self.activation == F.silu: if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8: - return swiglu(x, - quant_scale=self.down_proj.input_scale, - quant_type=torch.float8_e4m3fn) + if has_lora: + # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet. + # TODO: Remove this path when LoRA grouped_gemm supports FP8 + # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm + logger.warning( + f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}" + ) + return swiglu(x) + else: + return swiglu(x, + quant_scale=self.down_proj.input_scale, + quant_type=torch.float8_e4m3fn) else: return swiglu(x) elif callable(self.activation): @@ -155,7 +166,7 @@ def forward_lora( if h1_lora is not None: h1 = h1 + h1_lora - h2 = self._apply_activation(h1) + h2 = self._apply_activation(h1, has_lora=True) output = self.down_proj(h2, all_reduce_params=final_all_reduce_params, lora_params=lora_params, diff --git a/tensorrt_llm/bench/dataclasses/configuration.py b/tensorrt_llm/bench/dataclasses/configuration.py index a693333230c..6d8e703ee49 100755 --- a/tensorrt_llm/bench/dataclasses/configuration.py +++ b/tensorrt_llm/bench/dataclasses/configuration.py @@ -84,8 +84,24 @@ def get_llm_args(self) -> Dict: backend_cache_config = llm_args.pop("kv_cache_config", {}) llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config - return update_llm_args_with_extra_options(llm_args, - self.extra_llm_api_options) + updated_llm_args = update_llm_args_with_extra_options( + llm_args, self.extra_llm_api_options) + + if self.backend == "pytorch": + cuda_graph_config = updated_llm_args.pop( + "cuda_graph_config", llm_args["cuda_graph_config"]) + # Use runtime max_batch_size as cuda_graph_config.max_batch_size + # if both max_batch_size and batch_sizes are not set. + batch_sizes_set = cuda_graph_config.get("batch_sizes", + None) is not None + max_batch_size_set = cuda_graph_config.get("max_batch_size", + None) is not None + if not batch_sizes_set and not max_batch_size_set: + cuda_graph_config[ + "max_batch_size"] = self.settings_config.max_batch_size + updated_llm_args["cuda_graph_config"] = cuda_graph_config + + return updated_llm_args @model_validator(mode="after") def validate_full_config(self) -> RuntimeConfig: diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 2a86f50d650..5d45ebe4c12 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -125,13 +125,36 @@ def put(self, obj: Any): # Send data without HMAC self.socket.send_pyobj(obj) - def put_noblock(self, obj: Any): + def put_noblock(self, + obj: Any, + *, + retry: int = 1, + wait_time: float = 0.001): + ''' + Put an object into the queue without blocking, and retry if the send fails. + NOTE: It won't raise any error if the send fails. + + Parameters: + obj (Any): The object to send. + retry (int): The number of times to retry sending the object. + wait_time (float): The time to wait before retrying. + ''' + + assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed" + self.setup_lazily() with nvtx_range_debug("send", color="blue", category="IPC"): data = pickle.dumps(obj) # nosec B301 if self.use_hmac_encryption: data = self._sign_data(data) - self.socket.send(data, flags=zmq.NOBLOCK) + try: + self.socket.send(data, flags=zmq.NOBLOCK) + except zmq.Again: + if retry > 0: + time.sleep(wait_time) + self.put_noblock(obj, retry=retry - 1, wait_time=wait_time) + else: + logger.error(f"Failed to send object: {obj}") async def put_async(self, obj: Any): self.setup_lazily() diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index ec561bb2918..bf60f7edb6c 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -351,7 +351,7 @@ def pre_shutdown(self): # notify the workers to quit if all(not f.done() for f in self.mpi_futures): - self.request_queue.put_noblock(None) + self.request_queue.put_noblock(None, retry=4) def shutdown(self): if not self.workers_started: diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py index 7a25fa57f3d..f361b977b7d 100644 --- a/tensorrt_llm/llmapi/mpi_session.py +++ b/tensorrt_llm/llmapi/mpi_session.py @@ -435,7 +435,7 @@ def mpi_future_callback(self, future): f"RemoteMpiCommSessionServer received all results, sending to client\n", "green") try: - self.queue.put_noblock(self.results) + self.queue.put_noblock(self.results, retry=2) except zmq.ZMQError as e: # The client could be shutdown first. if e.errno == zmq.EAGAIN: diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index 65000841909..d08e5397504 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -518,6 +518,8 @@ def generate_api_docs_as_docstring(model: Type[BaseModel], # Format the argument documentation with 12 spaces indent for args arg_line = f"{indent} {field_name} ({type_str}): " + if status := field_info.get("status", None): + arg_line += f":tag:`{status}` " if field_description: arg_line += field_description.split('\n')[0] # First line with type @@ -557,20 +559,21 @@ class ApiParamTagger: ''' def __call__(self, cls: Type[BaseModel]) -> None: - self.process_pydantic_model(cls) + """ The main entry point to tag the api doc. """ + self._process_pydantic_model(cls) - def process_pydantic_model(self, cls: Type[BaseModel]) -> None: + def _process_pydantic_model(self, cls: Type[BaseModel]) -> None: """Process the Pydantic model to add tags to the fields. """ for field_name, field_info in cls.model_fields.items(): if field_info.json_schema_extra and 'status' in field_info.json_schema_extra: status = field_info.json_schema_extra['status'] - self.amend_pydantic_field_description_with_tags( + self._amend_pydantic_field_description_with_tags( cls, [field_name], status) - def amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel], - field_names: list[str], - tag: str) -> None: + def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel], + field_names: list[str], + tag: str) -> None: """Amend the description of the fields with tags. e.g. :tag:`beta` or :tag:`prototype` Args: diff --git a/tensorrt_llm/logger.py b/tensorrt_llm/logger.py index 27b10165d1a..99d9ddaa583 100644 --- a/tensorrt_llm/logger.py +++ b/tensorrt_llm/logger.py @@ -109,6 +109,7 @@ def log(self, severity, *msg): self._func_wrapper(severity)(" ".join(parts)) def log_once(self, severity, *msg, key): + assert key is not None, "key is required for log_once" if key not in self._appeared_keys: self._appeared_keys.add(key) self.log(severity, *msg) diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 9726efd8817..495724b2928 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -322,6 +322,8 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion raise ValueError("Disagg server returned more than one choice. This is currently not supported in disaggregated server.") if choices[0].disaggregated_params is None: raise ValueError("Context server did not return disaggregated params") + if choices[0].disaggregated_params.ctx_request_id is None: + raise ValueError("Invalid disaggregated params in context phase response.") return ctx_response diff --git a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml index f729cef1bdd..dde3b538762 100644 --- a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml +++ b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml @@ -8,6 +8,9 @@ meta-llama/Llama-3.3-70B-Instruct: accuracy: 48.03 - quant_algo: FP8 accuracy: 48.03 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 48.03 deepseek-ai/DeepSeek-R1: - quant_algo: NVFP4 accuracy: 70.45 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 05fe064c508..ad9a781c1ec 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1565,6 +1565,11 @@ def test_nvfp4_4gpus_online_eplb(self, fp8kv): @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"]) def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler, torch_compile, mtp_nextn, moe_backend): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, @@ -1613,8 +1618,10 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph, torch_compile, mtp_nextn, moe_backend): if torch_compile and pp_size > 1: pytest.skip("PP with torch.compile is not supported yet.") - if moe_backend == "TRTLLM" and get_sm_version() == 120: - pytest.skip("MOE TRTLLM backend does not support SM version 120") + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) # Picewise Cuda Graph cannot be enabled for nvfp4 attention dp. torch_compile_config = TorchCompileConfig( @@ -1885,6 +1892,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, attention_dp, cuda_graph, overlap_scheduler, max_batch_size, moe_backend): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, @@ -2509,6 +2521,11 @@ def test_nvfp4( torch_compile, ): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + torch_compile_config = TorchCompileConfig( enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph, @@ -2700,6 +2717,11 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend, eagle3): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -2779,7 +2801,7 @@ class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct" def test_auto_dtype(self): - with LLM(self.MODEL_PATH) as llm: + with LLM(self.MODEL_PATH, max_seq_len=4096) as llm: task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) task = MMLU(self.MODEL_NAME) @@ -3046,10 +3068,8 @@ def test_w4_2gpus(self, moe_backend, tp_size, pp_size, ep_size, class TestEXAONE4(LlmapiAccuracyTestHarness): MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B" - kv_cache_config = KvCacheConfig( - enable_block_reuse=False, - enable_partial_reuse=False, - max_attention_window=[4096, 4096, 4096, 131072]) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False) def test_auto_dtype(self): model_path = f"{llm_models_root()}/EXAONE-4.0-32B" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml index e6ec461b5eb..3d9cfda12ec 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml @@ -8,7 +8,7 @@ disable_overlap_scheduler: True context_servers: num_instances: 1 max_num_tokens: 512 - max_batch_size: 256 + max_batch_size: 64 cache_transceiver_config: backend: DEFAULT urls: @@ -16,7 +16,7 @@ context_servers: generation_servers: num_instances: 1 max_num_tokens: 256 - max_batch_size: 128 + max_batch_size: 32 cache_transceiver_config: backend: DEFAULT urls: diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 46c393ab488..89719b395d3 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -1302,7 +1302,7 @@ def get_config_for_benchmark(model_root, backend): "num_instances": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 320, + "max_seq_len": 384, "tensor_parallel_size": 1, "pipeline_parallel_size": 1, "disable_overlap_scheduler": True, @@ -1318,7 +1318,7 @@ def get_config_for_benchmark(model_root, backend): "pipeline_parallel_size": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 320, + "max_seq_len": 384, "cache_transceiver_config": { "backend": backend, "max_tokens_in_buffer": 512, diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 93611de040b..b49b4afb7cd 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -36,6 +36,42 @@ } +def mpi_publish_name(): + port_name = None + try: + port_name = MPI.Open_port() + MPI.Publish_name('my_port', port_name) + except MPI.Exception as e: + print(f"Error publishing port name: {e}") + raise e + except Exception as e: + print(f"Unexpected error publishing port name: {e}") + raise e + + return port_name + + +def mpi_initialize_intercomm(port_name): + intercomm = None + try: + intercomm = MPI.COMM_SELF.Accept(port_name) + except MPI.Exception as e: + print(f"Error accepting intercomm: {e}", flush=True) + raise + except Exception as e: + print(f"Unexpected error accepting intercomm: {e}", flush=True) + raise + return intercomm + + +def mpi_send_termination_request(intercomm): + if intercomm is not None: + # Send termination requests + intercomm.send(None, dest=0, tag=MPI_REQUEST) + intercomm.send(None, dest=1, tag=MPI_REQUEST) + print("Sent termination requests to the workers.") + + def model_path(model_name): llm_models_root = os.environ["LLM_MODELS_ROOT"] for name, path in MODEL_PATHS.items(): @@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config, model_name, rank): assert isinstance(pytorch_config, dict) print(f"Running worker {rank}") - port_name = MPI.Lookup_name('my_port') - intercomm = MPI.COMM_WORLD.Connect(port_name) + try: + port_name = MPI.Lookup_name('my_port') + intercomm = MPI.COMM_WORLD.Connect(port_name) + except MPI.Exception as e: + print(f"Error publishing port name: {e}") + raise e + except Exception as e: + print(f"Unexpected error publishing port name: {e}") + raise e session = MPI.COMM_WORLD.Split(color=rank, key=0) set_mpi_comm(session) @@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor: futures = [] @@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: - print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + print("Launched all the workers.", flush=True) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, output = responses[0] assert output[0].text == expected_output assert output[0].token_ids == expected_output_ids - + except Exception as e: + print(f"Exception encountered: {e}", flush=True) + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + print("Sending termination request", flush=True) + mpi_send_termination_request(intercomm) # Wait for all futures to complete + print("Waiting for all workers to terminate. ", flush=True) for future in futures: future.result() print("All workers terminated.") @@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity." @@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, intercomm.send(requests, dest=1, tag=MPI_REQUEST) output = intercomm.recv(source=1, tag=MPI_RESULT) + except MPI.Exception as e: + print(f"MPI Error") + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + mpi_send_termination_request(intercomm) # Wait for all futures to complete for future in futures: @@ -387,8 +431,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() prompt = "What is the capital of Germany?" @@ -402,9 +445,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -438,11 +482,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, intercomm.send(requests, dest=1, tag=MPI_REQUEST) output = intercomm.recv(source=1, tag=MPI_RESULT) + except MPI.Exception as e: + print(f"MPI Error") + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + mpi_send_termination_request(intercomm) # Wait for all futures to complete for future in futures: diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index ab0e0bf08d2..49915d3b479 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str, # lora-specific change for pytorch if 'pytorch' in model_label and 'loras' in model_label: + # Derive the requested number of adapters from model_label (segment like "loras:X") + lora_count = 1 + for part in model_label.split('-'): + if part.startswith('loras:'): + lora_count = max(1, int(part.split(':', 1)[1])) + break + lora_config = { 'lora_config': { 'lora_dir': lora_dirs if lora_dirs is not None else [], - 'max_lora_rank': 64 + 'max_lora_rank': 64, + 'max_loras': lora_count, + 'max_cpu_loras': lora_count, } } if 'phi_4_multimodal_instruct' in model_label: diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 0def4787de6..21bf49b363e 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1755,7 +1755,7 @@ def parse_output(text): for item in text_lists: item = item.replace(os.linesep, "") while True: - match = re.search(r"(Generated text: \'(.*?)\')", item, + match = re.search(r'Generated text: ([\'"])(.*?)\1', item, re.MULTILINE) if match is None: break @@ -2299,7 +2299,8 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv): marks=pytest.mark.skip_less_device_memory(80000)), pytest.param("gemma-3-27b-it", "gemma/gemma-3-27b-it", - marks=pytest.mark.skip_less_device_memory(80000)), + marks=(skip_post_blackwell, + pytest.mark.skip_less_device_memory(80000))), ]) def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, modality, use_cuda_graph): @@ -2407,9 +2408,9 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, }, "gemma-3-27b-it": { "image": [ - ["dramatic", "turbulent", "waves", "ocean", "overcast"], - ["half", "dome", "yosemite", "landmark", "rounded"], - ["flowing", "traffic", "vehicles", "road", "Changi"], + ["natural", "turbulent", "dramatic", "scene", "wave"], + ["image", "famous", "rock", "granite", "landmark"], + ["traffic", "moderate", "heavy", "flowing", "cars"], ], }, } @@ -2600,9 +2601,10 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("model_name,model_path", [ - ("gemma-3-27b-it", "gemma/gemma-3-27b-it"), ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"), ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"), + pytest.param( + "gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell), ]) def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, model_path): @@ -2645,8 +2647,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, }, "Phi-4-multimodal-instruct": { "image": [ - ["image", "depicts", "mountain", "half", "rock"], - ["road", "car", "lane", "traffic", "bus"], + ["object", "mountain", "weather", "clear", "clouds"], + ["traffic", "road", "vehicles", "cars", "bus"], ], }, } @@ -2674,6 +2676,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, cmd.append("--image_format=pil") cmd.append("--attention_backend=FLASHINFER") cmd.append("--disable_kv_cache_reuse") + cmd.append("--kv_cache_fraction=0.5") + cmd.append("--max_seq_len=1024") elif model_name == "Phi-4-multimodal-instruct": # Set max_seq_len to 4096 to use short rope factor. cmd.append("--max_seq_len=4096") @@ -2702,9 +2706,10 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, @pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("model_name,model_path", [ - ("gemma-3-27b-it", "gemma/gemma-3-27b-it"), ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"), ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"), + pytest.param( + "gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell), ]) def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name, model_path): @@ -2770,6 +2775,9 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name, cmd.append("--image_format=pil") cmd.append("--attention_backend=FLASHINFER") cmd.append("--disable_kv_cache_reuse") + cmd.append("--kv_cache_fraction=0.5") + cmd.append("--max_seq_len=1024") + elif model_name == "Phi-4-multimodal-instruct": # Set max_seq_len to 4096 to use short rope factor. cmd.append("--max_seq_len=4096") diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 138e7e2376e..ba1420e64e3 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -39,6 +39,13 @@ l0_a10: - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-] - test_e2e.py::test_trtllm_bench_invalid_token_pytorch[TinyLlama-1.1B-Chat-v1.0-TinyLlama-1.1B-Chat-v1.0] + # llmapi + - unittest/llmapi/test_llm_utils.py + - unittest/llmapi/test_gc_utils.py + - unittest/llmapi/test_reasoning_parser.py + - unittest/llmapi/test_serialization.py + - unittest/llmapi/test_utils.py + - unittest/llmapi/test_llm_args.py - condition: ranges: system_gpu_count: @@ -114,12 +121,6 @@ l0_a10: - unittest/bindings - unittest/test_model_runner_cpp.py - unittest/llmapi/test_build_cache.py - - unittest/llmapi/test_llm_utils.py - - unittest/llmapi/test_gc_utils.py - - unittest/llmapi/test_reasoning_parser.py - - unittest/llmapi/test_serialization.py - - unittest/llmapi/test_utils.py - - unittest/llmapi/test_llm_args.py - accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 339b12fba26..eb597d25a0d 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -106,7 +106,7 @@ l0_h100: - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B] - test_e2e.py::test_openai_chat_harmony - test_e2e.py::test_openai_responses - - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] + - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] TIMEOUT (90) # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype - condition: @@ -227,7 +227,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized - - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype + - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 3c0d6562e6d..18968eb4a0a 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -284,6 +284,12 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKI examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143) examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143) examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5465173) +test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444095) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5451272) disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642) examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146) @@ -337,17 +343,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_ep accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5488118) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140) -full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051) -full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106) -full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) -test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781) -accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169) test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523) cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078) diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py index 5a38f0d0788..8b6ac42cd28 100644 --- a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py +++ b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py @@ -17,6 +17,7 @@ def similar(a, b, threshold=0.9): return SequenceMatcher(None, a, b).ratio() >= threshold +@pytest.mark.skip(reason="https://nvbugs/5470782") @pytest.mark.parametrize("model_name", ["DeepSeek-V3-Lite"], ids=["deepseekv3_lite"]) @pytest.mark.parametrize("backend", ["TRTLLM"], ids=["trtllm"]) diff --git a/tests/unittest/llmapi/apps/_test_openai_multi_chat.py b/tests/unittest/llmapi/apps/_test_openai_multi_chat.py index 9ed9a654c52..1265b58bd90 100644 --- a/tests/unittest/llmapi/apps/_test_openai_multi_chat.py +++ b/tests/unittest/llmapi/apps/_test_openai_multi_chat.py @@ -65,7 +65,10 @@ def engine_from_fp8_quantization(model_name): @pytest.fixture(scope="module") def server(model_name: str, engine_from_fp8_quantization: str): model_path = get_model_path(model_name) - args = ["--tp_size", "2", "--tokenizer", model_path] + args = [ + "--tp_size", "2", "--tokenizer", model_path, "--backend", "trt", + "--max_num_tokens", "20480", "--max_batch_size", "128" + ] with RemoteOpenAIServer(engine_from_fp8_quantization, args) as remote_server: yield remote_server diff --git a/tests/unittest/llmapi/test_utils.py b/tests/unittest/llmapi/test_utils.py index fc5876cdb15..5488f7c7bad 100644 --- a/tests/unittest/llmapi/test_utils.py +++ b/tests/unittest/llmapi/test_utils.py @@ -1,4 +1,6 @@ -from tensorrt_llm.llmapi.utils import ApiStatusRegistry +from tensorrt_llm.llmapi import LlmArgs +from tensorrt_llm.llmapi.utils import (ApiStatusRegistry, + generate_api_docs_as_docstring) def test_api_status_registry(): @@ -24,3 +26,9 @@ def _my_method(self, *args, **kwargs): pass assert ApiStatusRegistry.get_api_status(App._my_method) == "beta" + + +def test_generate_api_docs_as_docstring(): + doc = generate_api_docs_as_docstring(LlmArgs) + assert ":tag:`beta`" in doc, "the label is not generated" + print(doc)