From cc3beac508febf3039040aa0aa2da586040d4e20 Mon Sep 17 00:00:00 2001
From: Jiagan Cheng
Date: Tue, 19 Aug 2025 07:57:48 +0000
Subject: [PATCH 1/3] use runtime max_batch_size if
 cuda_graph_config.max_batch_size is not provided

Signed-off-by: Jiagan Cheng
---
 tensorrt_llm/llmapi/llm_args.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index b7d46ed6fa2..bfa3f3da12b 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2230,7 +2230,11 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
             else:
                 config.max_batch_size = max(config.batch_sizes)
         else:
-            max_batch_size = config.max_batch_size or 128
+            # Use the max batch size from:
+            # 1. cuda_graph_config.max_batch_size, if provided,
+            # 2. base_llm_args.max_batch_size, if provided,
+            # 3. default value 128.
+            max_batch_size = config.max_batch_size or self.max_batch_size or 128
             generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes

From 93d792dbfc93db8aca46baf96d7f6d57e2231595 Mon Sep 17 00:00:00 2001
From: Jiagan Cheng
Date: Wed, 20 Aug 2025 10:28:37 +0000
Subject: [PATCH 2/3] decrease kvcache fraction to avoid OOM

Signed-off-by: Jiagan Cheng
---
 tests/integration/defs/accuracy/test_llm_api_pytorch.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 22d04b26145..183d87022b3 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2015,11 +2015,14 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
+        # decrease fraction to avoid OOM on RTX 5090
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
 
         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B",
                  tensor_parallel_size=tp_size,
                  pipeline_parallel_size=pp_size,
                  moe_expert_parallel_size=ep_size,
+                 kv_cache_config=kv_cache_config,
                  **pytorch_config,
                  enable_attention_dp=attention_dp) as llm:
             task = CnnDailymail(self.MODEL_NAME)

From 6c7142f3fad71f116f1490f088f0948e55a530a0 Mon Sep 17 00:00:00 2001
From: Jiagan Cheng
Date: Thu, 21 Aug 2025 10:50:41 +0000
Subject: [PATCH 3/3] limit the changes in trtllm-bench to avoid OOM in CI

Signed-off-by: Jiagan Cheng
---
 .../bench/dataclasses/configuration.py    | 20 +++++++++++++++++--
 tensorrt_llm/llmapi/llm_args.py           |  6 +-----
 .../defs/accuracy/test_llm_api_pytorch.py |  3 ---
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/tensorrt_llm/bench/dataclasses/configuration.py b/tensorrt_llm/bench/dataclasses/configuration.py
index a693333230c..6d8e703ee49 100755
--- a/tensorrt_llm/bench/dataclasses/configuration.py
+++ b/tensorrt_llm/bench/dataclasses/configuration.py
@@ -84,8 +84,24 @@ def get_llm_args(self) -> Dict:
         backend_cache_config = llm_args.pop("kv_cache_config", {})
         llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config
 
-        return update_llm_args_with_extra_options(llm_args,
-                                                  self.extra_llm_api_options)
+        updated_llm_args = update_llm_args_with_extra_options(
+            llm_args, self.extra_llm_api_options)
+
+        if self.backend == "pytorch":
+            cuda_graph_config = updated_llm_args.pop(
+                "cuda_graph_config", llm_args["cuda_graph_config"])
+            # Use runtime max_batch_size as cuda_graph_config.max_batch_size
+            # if both max_batch_size and batch_sizes are not set.
+            batch_sizes_set = cuda_graph_config.get("batch_sizes",
+                                                    None) is not None
+            max_batch_size_set = cuda_graph_config.get("max_batch_size",
+                                                       None) is not None
+            if not batch_sizes_set and not max_batch_size_set:
+                cuda_graph_config[
+                    "max_batch_size"] = self.settings_config.max_batch_size
+            updated_llm_args["cuda_graph_config"] = cuda_graph_config
+
+        return updated_llm_args
 
     @model_validator(mode="after")
     def validate_full_config(self) -> RuntimeConfig:

diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index bfa3f3da12b..b7d46ed6fa2 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2230,11 +2230,7 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
             else:
                 config.max_batch_size = max(config.batch_sizes)
         else:
-            # Use the max batch size from:
-            # 1. cuda_graph_config.max_batch_size, if provided,
-            # 2. base_llm_args.max_batch_size, if provided,
-            # 3. default value 128.
-            max_batch_size = config.max_batch_size or self.max_batch_size or 128
+            max_batch_size = config.max_batch_size or 128
             generated_sizes = CudaGraphConfig._generate_cuda_graph_batch_sizes(
                 max_batch_size, config.enable_padding)
             config.batch_sizes = generated_sizes

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 183d87022b3..22d04b26145 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2015,14 +2015,11 @@ def test_bf16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
-        # decrease fraction to avoid OOM on RTX 5090
-        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
 
         with LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B",
                  tensor_parallel_size=tp_size,
                  pipeline_parallel_size=pp_size,
                  moe_expert_parallel_size=ep_size,
-                 kv_cache_config=kv_cache_config,
                  **pytorch_config,
                  enable_attention_dp=attention_dp) as llm:
             task = CnnDailymail(self.MODEL_NAME)
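
Note (illustration, not part of the patches): after patch 3, the runtime-max-batch-size fallback lives only in trtllm-bench's get_llm_args() and is applied only when the user's extra LLM API options set neither cuda_graph_config.batch_sizes nor cuda_graph_config.max_batch_size. The standalone Python sketch below restates that precedence so it can be run without TensorRT-LLM installed; the plain-dict inputs and the helper name are illustrative stand-ins, not the real RuntimeConfig API.

    # Standalone sketch of the precedence implemented in get_llm_args() for the
    # pytorch backend; dict-based stand-ins for RuntimeConfig are assumptions.
    from typing import Any, Dict, Optional


    def resolve_cuda_graph_max_batch_size(
            cuda_graph_config: Dict[str, Any],
            runtime_max_batch_size: Optional[int]) -> Dict[str, Any]:
        """Fill in max_batch_size from the runtime setting only when the user
        set neither batch_sizes nor max_batch_size."""
        batch_sizes_set = cuda_graph_config.get("batch_sizes") is not None
        max_batch_size_set = cuda_graph_config.get("max_batch_size") is not None
        if not batch_sizes_set and not max_batch_size_set:
            cuda_graph_config["max_batch_size"] = runtime_max_batch_size
        return cuda_graph_config


    # User options leave both fields unset: the benchmark's runtime
    # max_batch_size (256 here) becomes the CUDA graph capture limit.
    print(resolve_cuda_graph_max_batch_size({"enable_padding": True}, 256))
    # {'enable_padding': True, 'max_batch_size': 256}

    # An explicit user value is left untouched.
    print(resolve_cuda_graph_max_batch_size({"max_batch_size": 64}, 256))
    # {'max_batch_size': 64}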