diff --git a/tensorrt_llm/bench/dataclasses/configuration.py b/tensorrt_llm/bench/dataclasses/configuration.py
index a693333230c..6d8e703ee49 100755
--- a/tensorrt_llm/bench/dataclasses/configuration.py
+++ b/tensorrt_llm/bench/dataclasses/configuration.py
@@ -84,8 +84,24 @@ def get_llm_args(self) -> Dict:
         backend_cache_config = llm_args.pop("kv_cache_config", {})
         llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config
 
-        return update_llm_args_with_extra_options(llm_args,
-                                                  self.extra_llm_api_options)
+        updated_llm_args = update_llm_args_with_extra_options(
+            llm_args, self.extra_llm_api_options)
+
+        if self.backend == "pytorch":
+            cuda_graph_config = updated_llm_args.pop(
+                "cuda_graph_config", llm_args["cuda_graph_config"])
+            # Use the runtime max_batch_size as cuda_graph_config.max_batch_size
+            # if neither max_batch_size nor batch_sizes is set.
+            batch_sizes_set = cuda_graph_config.get("batch_sizes",
+                                                    None) is not None
+            max_batch_size_set = cuda_graph_config.get("max_batch_size",
+                                                       None) is not None
+            if not batch_sizes_set and not max_batch_size_set:
+                cuda_graph_config[
+                    "max_batch_size"] = self.settings_config.max_batch_size
+            updated_llm_args["cuda_graph_config"] = cuda_graph_config
+
+        return updated_llm_args
 
     @model_validator(mode="after")
     def validate_full_config(self) -> RuntimeConfig:
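
The defaulting rule this patch introduces can be illustrated in isolation. The sketch below is not part of the patch; it is a minimal stand-alone reproduction of the new logic using plain dicts and a hypothetical resolve_cuda_graph_config helper in place of the real RuntimeConfig fields, so the fallback behavior can be checked without the benchmark harness.

    # Minimal sketch (assumed names, not the real API) of the fallback added
    # above: the runtime max_batch_size is applied only when the user set
    # neither "batch_sizes" nor "max_batch_size" in cuda_graph_config.
    def resolve_cuda_graph_config(cuda_graph_config: dict,
                                  runtime_max_batch_size: int) -> dict:
        batch_sizes_set = cuda_graph_config.get("batch_sizes") is not None
        max_batch_size_set = cuda_graph_config.get("max_batch_size") is not None
        if not batch_sizes_set and not max_batch_size_set:
            cuda_graph_config["max_batch_size"] = runtime_max_batch_size
        return cuda_graph_config

    # Neither key set: the runtime value is used as the default.
    assert resolve_cuda_graph_config({}, 256) == {"max_batch_size": 256}
    # An explicit batch_sizes list suppresses the fallback entirely.
    assert resolve_cuda_graph_config({"batch_sizes": [1, 2, 4]}, 256) == {
        "batch_sizes": [1, 2, 4]
    }
    # An explicit max_batch_size is left untouched.
    assert resolve_cuda_graph_config({"max_batch_size": 8}, 256) == {
        "max_batch_size": 8
    }

Note that the patch gates this on self.backend == "pytorch" and applies it after update_llm_args_with_extra_options, so values supplied via extra_llm_api_options still win over the runtime default.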