
Commit 7e3b3fa

fix: Add default configs in LLMAPI. Fixes OOM issues (#2198)
1 parent: 625578c

File tree: 2 files changed (+128 −1 lines)

components/backends/trtllm/src/dynamo/trtllm/main.py

Lines changed: 47 additions & 0 deletions
@@ -8,8 +8,16 @@
 
 import uvloop
 from tensorrt_llm import SamplingParams
+from tensorrt_llm.llmapi import (
+    BuildConfig,
+    CapacitySchedulerPolicy,
+    DynamicBatchConfig,
+    KvCacheConfig,
+    SchedulerConfig,
+)
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
 from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
+from torch.cuda import device_count
 
 from dynamo.llm import ModelType, register_llm
 from dynamo.runtime import DistributedRuntime, dynamo_worker
@@ -84,12 +92,51 @@ async def init(runtime: DistributedRuntime, config: Config):
     # Convert model path to Path object if it's a local path, otherwise keep as string
     model_path = str(config.model_path)
 
+    if config.gpus_per_node is None:
+        gpus_per_node = device_count()
+        if gpus_per_node == 0:
+            raise ValueError("No GPU devices found on the node")
+    else:
+        gpus_per_node = config.gpus_per_node
+
+    build_config = BuildConfig(
+        max_batch_size=config.max_batch_size,
+        max_num_tokens=config.max_num_tokens,
+        max_beam_width=config.max_beam_width,
+        max_seq_len=config.max_seq_len,
+    )
+
+    kv_cache_config = KvCacheConfig(
+        free_gpu_memory_fraction=config.free_gpu_memory_fraction
+    )
+
+    dynamic_batch_config = DynamicBatchConfig(
+        enable_batch_size_tuning=True,
+        enable_max_num_tokens_tuning=False,
+        dynamic_batch_moving_average_window=128,
+    )
+    scheduler_config = SchedulerConfig(
+        capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+        dynamic_batch_config=dynamic_batch_config,
+    )
+
     arg_map = {
         "model": model_path,
+        "scheduler_config": scheduler_config,
         "tensor_parallel_size": config.tensor_parallel_size,
+        "pipeline_parallel_size": config.pipeline_parallel_size,
+        "moe_expert_parallel_size": config.expert_parallel_size,
         "backend": "pytorch",
         "skip_tokenizer_init": True,
+        "build_config": build_config,
+        "kv_cache_config": kv_cache_config,
+        "gpus_per_node": gpus_per_node,
+        "max_num_tokens": config.max_num_tokens,
+        "max_seq_len": config.max_seq_len,
+        "max_beam_width": config.max_beam_width,
+        "max_batch_size": config.max_batch_size,
     }
+
     if config.extra_engine_args != "":
         # TODO: Support extra engine args from json file as well.
         arg_map = update_llm_args_with_extra_options(arg_map, config.extra_engine_args)
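For context, the arg_map assembled above is the keyword dictionary that ultimately configures the TRT-LLM engine. The engine-creation call itself is outside this diff, so the following is only a minimal sketch of the idea; the LLM(**arg_map) construction and the example values are assumptions, not the worker's actual code path.

# Hedged sketch: how a keyword map like the one above is typically consumed by the LLMAPI.
from tensorrt_llm.llmapi import LLM, BuildConfig, KvCacheConfig

arg_map = {
    "model": "/models/example-model",  # placeholder path, not from the diff
    "build_config": BuildConfig(max_batch_size=64, max_num_tokens=8192),
    "kv_cache_config": KvCacheConfig(free_gpu_memory_fraction=0.85),
}
llm = LLM(**arg_map)  # bounded batch/token limits plus a capped KV cache are what address the OOM issue

Capping free_gpu_memory_fraction and pinning max_batch_size/max_num_tokens is the mechanism by which this commit avoids the out-of-memory failures mentioned in the title.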

components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py

Lines changed: 81 additions & 1 deletion
@@ -4,6 +4,8 @@
 import argparse
 from typing import Optional
 
+from tensorrt_llm.llmapi import BuildConfig
+
 from dynamo.trtllm.request_handlers.handler_base import (
     DisaggregationMode,
     DisaggregationStrategy,
@@ -27,8 +29,16 @@ def __init__(self) -> None:
         self.model_path: str = ""
         self.served_model_name: Optional[str] = None
         self.tensor_parallel_size: int = 1
+        self.pipeline_parallel_size: int = 1
+        self.expert_parallel_size: Optional[int] = None
         self.kv_block_size: int = 32
         self.migration_limit: int = 0
+        self.gpus_per_node: Optional[int] = None
+        self.max_batch_size: int = BuildConfig.max_batch_size
+        self.max_num_tokens: int = BuildConfig.max_num_tokens
+        self.max_seq_len: int = BuildConfig.max_seq_len
+        self.max_beam_width: int = BuildConfig.max_beam_width
+        self.free_gpu_memory_fraction: Optional[float] = None
        self.extra_engine_args: str = ""
         self.publish_events_and_metrics: bool = False
         self.disaggregation_mode: DisaggregationMode = DEFAULT_DISAGGREGATION_MODE
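The new fields borrow their defaults from class attributes on BuildConfig, so the worker inherits whatever limits the installed TRT-LLM release ships with. A quick way to inspect those defaults (the concrete values vary by TRT-LLM version, so none are hard-coded here):

# Read the class-level BuildConfig defaults that the new Config fields fall back to.
from tensorrt_llm.llmapi import BuildConfig

for name in ("max_batch_size", "max_num_tokens", "max_seq_len", "max_beam_width"):
    print(name, "=", getattr(BuildConfig, name))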
@@ -45,7 +55,15 @@ def __str__(self) -> str:
             f"model_path={self.model_path}, "
             f"served_model_name={self.served_model_name}, "
             f"tensor_parallel_size={self.tensor_parallel_size}, "
+            f"pipeline_parallel_size={self.pipeline_parallel_size}, "
+            f"expert_parallel_size={self.expert_parallel_size}, "
             f"kv_block_size={self.kv_block_size}, "
+            f"gpus_per_node={self.gpus_per_node}, "
+            f"max_batch_size={self.max_batch_size}, "
+            f"max_num_tokens={self.max_num_tokens}, "
+            f"max_seq_len={self.max_seq_len}, "
+            f"max_beam_width={self.max_beam_width}, "
+            f"free_gpu_memory_fraction={self.free_gpu_memory_fraction}, "
             f"extra_engine_args={self.extra_engine_args}, "
             f"migration_limit={self.migration_limit}, "
             f"publish_events_and_metrics={self.publish_events_and_metrics}, "
@@ -108,8 +126,21 @@ def cmd_line_args():
         help="Name to serve the model under. Defaults to deriving it from model path.",
     )
     parser.add_argument(
-        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
+        "--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size."
+    )
+    parser.add_argument(
+        "--pipeline-parallel-size",
+        type=int,
+        default=None,
+        help="Pipeline parallelism size.",
+    )
+    parser.add_argument(
+        "--expert-parallel-size",
+        type=int,
+        default=None,
+        help="expert parallelism size.",
     )
+
     # IMPORTANT: We should ideally not expose this to users. We should be able to
     # query the block size from the TRTLLM engine.
     parser.add_argument(
@@ -121,6 +152,43 @@ def cmd_line_args():
         default=0,
         help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
     )
+    parser.add_argument(
+        "--gpus-per-node",
+        type=int,
+        default=None,
+        help="Number of GPUs per node. If not provided, will be inferred from the environment.",
+    )
+    parser.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=BuildConfig.max_batch_size,
+        help="Maximum number of requests that the engine can schedule.",
+    )
+    parser.add_argument(
+        "--max-num-tokens",
+        type=int,
+        default=BuildConfig.max_num_tokens,
+        help="Maximum number of batched input tokens after padding is removed in each batch.",
+    )
+    parser.add_argument(
+        "--max-seq-len",
+        type=int,
+        default=BuildConfig.max_seq_len,
+        help="Maximum total length of one request, including prompt and outputs. "
+        "If unspecified, the value is deduced from the model config.",
+    )
+    parser.add_argument(
+        "--max-beam-width",
+        type=int,
+        default=BuildConfig.max_beam_width,
+        help="Maximum number of beams for beam search decoding.",
+    )
+    parser.add_argument(
+        "--free-gpu-memory-fraction",
+        type=float,
+        default=None,
+        help="Free GPU memory fraction reserved for KV Cache, after allocating model weights and buffers.",
+    )
 
     parser.add_argument(
         "--extra-engine-args",
@@ -195,6 +263,18 @@ def cmd_line_args():
     config.next_endpoint = args.next_endpoint
 
     config.tensor_parallel_size = args.tensor_parallel_size
+    if args.pipeline_parallel_size is not None:
+        config.pipeline_parallel_size = args.pipeline_parallel_size
+    if args.expert_parallel_size is not None:
+        config.expert_parallel_size = args.expert_parallel_size
+    if args.gpus_per_node is not None:
+        config.gpus_per_node = args.gpus_per_node
+    if args.free_gpu_memory_fraction is not None:
+        config.free_gpu_memory_fraction = args.free_gpu_memory_fraction
+    config.max_batch_size = args.max_batch_size
+    config.max_num_tokens = args.max_num_tokens
+    config.max_seq_len = args.max_seq_len
+    config.max_beam_width = args.max_beam_width
     config.kv_block_size = args.kv_block_size
     config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args
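Flags that default to None are copied onto the config only when explicitly provided, so the __init__ defaults (for example pipeline_parallel_size = 1) stay in effect otherwise. A small check of that behavior, assuming Config is the class whose __init__ appears above and is importable from the same module:

from dynamo.trtllm.utils.trtllm_utils import Config

cfg = Config()
# Until cmd_line_args() overrides them, fields keep their __init__ defaults.
assert cfg.pipeline_parallel_size == 1
assert cfg.gpus_per_node is None
assert cfg.free_gpu_memory_fraction is None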
