
Commit 7ba91cb

optimize: ADP schedule optimization
Signed-off-by: yunruis <[email protected]>
1 parent 6135f75 commit 7ba91cb

9 files changed (+200, −16 lines)


examples/llm-api/quickstart_advanced.py

Lines changed: 19 additions & 4 deletions
@@ -1,10 +1,11 @@
 import argparse
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
-                                 DraftTargetDecodingConfig, EagleDecodingConfig,
-                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
-                                 NGramDecodingConfig, TorchCompileConfig)
+from tensorrt_llm.llmapi import (AttentionDpConfig, AutoDecodingConfig,
+                                 CudaGraphConfig, DraftTargetDecodingConfig,
+                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
+                                 MTPDecodingConfig, NGramDecodingConfig,
+                                 TorchCompileConfig)
 
 example_prompts = [
     "Hello, my name is",
@@ -57,6 +58,13 @@ def add_llm_args(parser):
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
+    parser.add_argument('--attention_dp_enable_balance',
+                        default=False,
+                        action='store_true')
+    parser.add_argument('--attention_dp_time_out_iters', type=int, default=0)
+    parser.add_argument('--attention_dp_batching_wait_iters',
+                        type=int,
+                        default=0)
     parser.add_argument('--enable_trtllm_sampler',
                         default=False,
                         action='store_true')
@@ -196,6 +204,12 @@ def setup_llm(args, **kwargs):
         enable_padding=args.cuda_graph_padding_enabled,
     ) if args.use_cuda_graph else None
 
+    attention_dp_config = AttentionDpConfig(
+        enable_balance=args.attention_dp_enable_balance,
+        timeout_iters=args.attention_dp_time_out_iters,
+        batching_wait_iters=args.attention_dp_batching_wait_iters,
+    )
+
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
@@ -218,6 +232,7 @@ def setup_llm(args, **kwargs):
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
         enable_attention_dp=args.enable_attention_dp,
+        attention_dp_config=attention_dp_config,
         tensor_parallel_size=args.tp_size,
         pipeline_parallel_size=args.pp_size,
         moe_expert_parallel_size=args.moe_ep_size,
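For reference, the new arguments slot into the existing quickstart CLI. A hypothetical invocation (the model directory and tp_size are placeholders, and the knob values shown are simply the PyTorchConfig defaults rather than anything prescribed by this commit) might look like:

python examples/llm-api/quickstart_advanced.py \
    --model_dir /path/to/model \
    --tp_size 4 \
    --enable_attention_dp \
    --attention_dp_enable_balance \
    --attention_dp_time_out_iters 50 \
    --attention_dp_batching_wait_iters 10

Balancing only takes effect when --enable_attention_dp is also set, since the executor consults the balance knobs inside the attention-DP scheduling path.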

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,9 @@ def __init__(
         self.pytorch_backend_config.enable_iter_perf_stats = False
         self.pytorch_backend_config.enable_iter_req_stats = False
         self.pytorch_backend_config.stream_interval = 1
+        self.pytorch_backend_config.attention_dp_enable_balance = False
+        self.pytorch_backend_config.attention_dp_time_out_iters = 50
+        self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.iter_counter = 0
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ class PyTorchConfig:
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
 
+    attention_dp_enable_balance: bool = False
+    attention_dp_time_out_iters: int = 50
+    attention_dp_batching_wait_iters: int = 10
+
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
 

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 63 additions & 2 deletions
@@ -177,6 +177,9 @@ def __init__(self,
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.stream_interval = model_engine.pytorch_backend_config.stream_interval
+        self.attention_dp_enable_balance = model_engine.pytorch_backend_config.attention_dp_enable_balance
+        self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
+        self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -215,6 +218,9 @@ def __init__(self,
             self.draft_model_engine.warmup(self.resource_manager)
 
         self.is_shutdown = False
+        self.max_batch_size = max_batch_size
+        self.adp_ctx_waiting_iters_count = 0
+        self.adp_ctx_batching_wait_iters_count = 0
 
         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1131,13 +1137,68 @@ def _add_kv_cache_events(self):
         # to be transferred to main thread when user needs them.
         kv_cache_manager.flush_iteration_events()
 
+    def _balance_adp_requests(self, context_requests: list[LlmRequest],
+                              generation_requests: list[LlmRequest]):
+        balanced_context_requests = context_requests
+        num_scheduled_context_requests = len(context_requests)
+        num_scheduled_generation_requests = len(generation_requests)
+        num_scheduled_tokens = sum(
+            [len(req.get_tokens(0))
+             for req in context_requests]) + num_scheduled_generation_requests
+        responses_list = self.dist.tp_allgather([
+            num_scheduled_context_requests, num_scheduled_generation_requests,
+            num_scheduled_tokens
+        ])
+        all_ranks_num_scheduled_context_requests = [
+            response[0] for response in responses_list
+        ]
+        all_ranks_num_scheduled_generation_requests = [
+            response[1] for response in responses_list
+        ]
+        all_ranks_have_free_ctx_slots = all([
+            num_gen < self.max_batch_size
+            for num_gen in all_ranks_num_scheduled_generation_requests
+        ])
+        all_ranks_have_ctx_requests = all([
+            num_ctx > 0 for num_ctx in all_ranks_num_scheduled_context_requests
+        ])
+        all_ranks_have_gen_requests = all([
+            num_gen > 0
+            for num_gen in all_ranks_num_scheduled_generation_requests
+        ])
+
+        if self.attention_dp_enable_balance:
+            # wait for all ranks to have context requests
+            if all_ranks_have_free_ctx_slots and all_ranks_have_ctx_requests:
+                self.adp_ctx_waiting_iters_count = 0
+                # balance number of context requests across ranks
+                if all_ranks_have_gen_requests:
+                    if self.adp_ctx_batching_wait_iters_count < self.attention_dp_batching_wait_iters:
+                        self.adp_ctx_batching_wait_iters_count += 1
+                        balanced_context_requests = []
+                    else:
+                        self.adp_ctx_batching_wait_iters_count = 0
+            else:
+                self.adp_ctx_waiting_iters_count += 1
+                balanced_context_requests = []
+                timeout_reached = self.adp_ctx_waiting_iters_count >= self.attention_dp_time_out_iters
+                if timeout_reached or not all_ranks_have_gen_requests:
+                    self.adp_ctx_waiting_iters_count = 0
+                    balanced_context_requests = context_requests
+        return balanced_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
             self.active_requests, self.inflight_req_ids)
-        scheduled_requests = ScheduledRequests()
+        scheduled_context_requests = scheduler_output.context_requests
+        if self.enable_attention_dp and self.attention_dp_enable_balance:
+            scheduled_context_requests = self._balance_adp_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
 
-        scheduled_requests.context_requests = scheduler_output.context_requests
+        scheduled_requests = ScheduledRequests()
+        scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
         scheduled_requests.paused_requests = scheduler_output.paused_requests
         return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
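In plain terms, the new _balance_adp_requests step defers a rank's context (prefill) requests until every attention-DP rank has a free generation slot and context work of its own, so prefill batches start together across ranks; batching_wait_iters lets ranks idle a few extra iterations to accumulate more context requests while generation is running everywhere, and time_out_iters bounds how long a rank may hold back its own prefill before scheduling it anyway. The sketch below is a simplified, single-process restatement of that gating decision: the per-rank counts that the real code exchanges via dist.tp_allgather are passed in directly, and the helper names are illustrative rather than part of this commit.

from dataclasses import dataclass
from typing import List


@dataclass
class AdpBalanceState:
    # Mirrors adp_ctx_waiting_iters_count / adp_ctx_batching_wait_iters_count.
    waiting_iters: int = 0
    batching_wait_iters: int = 0


def balance_ctx_requests(ctx_requests: list, per_rank_num_ctx: List[int],
                         per_rank_num_gen: List[int], max_batch_size: int,
                         time_out_iters: int, batching_wait_iters: int,
                         state: AdpBalanceState) -> list:
    """Return the context requests this rank should schedule this iteration."""
    all_have_free_slots = all(n < max_batch_size for n in per_rank_num_gen)
    all_have_ctx = all(n > 0 for n in per_rank_num_ctx)
    all_have_gen = all(n > 0 for n in per_rank_num_gen)

    balanced = ctx_requests
    if all_have_free_slots and all_have_ctx:
        # Every rank can start prefill now; optionally wait a few more
        # iterations to batch additional context requests together.
        state.waiting_iters = 0
        if all_have_gen:
            if state.batching_wait_iters < batching_wait_iters:
                state.batching_wait_iters += 1
                balanced = []
            else:
                state.batching_wait_iters = 0
    else:
        # Some rank has no context work or no free slot: hold ours back,
        # but only for up to time_out_iters iterations.
        state.waiting_iters += 1
        balanced = []
        if state.waiting_iters >= time_out_iters or not all_have_gen:
            state.waiting_iters = 0
            balanced = ctx_requests
    return balanced

With both knobs at 0 (the quickstart defaults), the delays collapse and context requests are scheduled immediately, so behavior stays close to the unbalanced scheduler unless the knobs are raised.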

tensorrt_llm/llmapi/__init__.py

Lines changed: 11 additions & 9 deletions
@@ -4,15 +4,16 @@
 from .build_cache import BuildCacheConfig
 from .llm import LLM, RequestOutput
 # yapf: disable
-from .llm_args import (AutoDecodingConfig, BatchingType, CacheTransceiverConfig,
-                       CalibConfig, CapacitySchedulerPolicy,
-                       ContextChunkingPolicy, CudaGraphConfig,
-                       DraftTargetDecodingConfig, DynamicBatchConfig,
-                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
-                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
-                       MedusaDecodingConfig, MoeConfig, MTPDecodingConfig,
-                       NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
-                       TorchLlmArgs, TrtLlmArgs, UserProvidedDecodingConfig)
+from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
+                       CacheTransceiverConfig, CalibConfig,
+                       CapacitySchedulerPolicy, ContextChunkingPolicy,
+                       CudaGraphConfig, DraftTargetDecodingConfig,
+                       DynamicBatchConfig, EagleDecodingConfig,
+                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
+                       LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
+                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
+                       UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession
@@ -54,4 +55,5 @@
     'TorchLlmArgs',
     'TrtLlmArgs',
     'AutoDecodingConfig',
+    'AttentionDpConfig',
 ]

tensorrt_llm/llmapi/llm_args.py

Lines changed: 56 additions & 1 deletion
@@ -187,6 +187,23 @@ def from_dict(cls, data: dict):
         return cls(**data)
 
 
+class AttentionDpConfig(StrictBaseModel):
+    """
+    Configuration for attention DP.
+    """
+    enable_balance: bool = Field(default=False,
+                                 description="Whether to enable balance.")
+    timeout_iters: int = Field(
+        default=50, description="The number of iterations to timeout.")
+    batching_wait_iters: int = Field(
+        default=10,
+        description="The number of iterations to wait for batching.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+
 @dataclass
 class _ParallelConfig:
     ''' The model distribution configs for LLM. '''
@@ -1988,6 +2005,11 @@ class TorchLlmArgs(BaseLlmArgs):
             Note that each CUDA graph can use up to 200 MB of extra memory.",
         status="beta")
 
+    attention_dp_config: Optional[AttentionDpConfig] = Field(
+        default=None,
+        description="Optimized load-balancing for the DP Attention scheduler.",
+        status="beta")
+
     disable_overlap_scheduler: bool = Field(
         default=False,
         description="Disable the overlap scheduler.",
@@ -2253,6 +2275,29 @@ def warn_on_unstable_feature_usage(self) -> 'TorchLlmArgs':
 
         return self
 
+    @model_validator(mode='after')
+    def validate_attention_dp_config(self) -> 'TorchLlmArgs':
+        """Validate attention DP configuration.
+
+        Ensures that:
+        1. If attention_dp_config.enable_balance is true, attention_dp_config.batching_wait_iters must be greater or equal to 0
+        2. If attention_dp_config.enable_balance is true, attention_dp_config.timeout_iters must be greater or equal to 0
+        """
+        if self.attention_dp_config is None:
+            return self
+
+        config = self.attention_dp_config
+        if config.enable_balance:
+            if config.batching_wait_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.batching_wait_iters must be greater or equal to 0 when enable_balance is true"
+                )
+            if config.timeout_iters < 0:
+                raise ValueError(
+                    "attention_dp_config.timeout_iters must be greater or equal to 0 when enable_balance is true"
+                )
+        return self
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2303,7 +2348,16 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_min_latency=self.enable_min_latency,
             stream_interval=self.stream_interval,
             force_dynamic_quantization=self.force_dynamic_quantization,
-            allreduce_strategy=self.allreduce_strategy)
+            allreduce_strategy=self.allreduce_strategy,
+            attention_dp_enable_balance=bool(
+                self.attention_dp_config is not None
+                and self.attention_dp_config.enable_balance),
+            attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
+            if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['timeout_iters'].default,
+            attention_dp_batching_wait_iters=self.attention_dp_config.
+            batching_wait_iters if self.attention_dp_config is not None else
+            AttentionDpConfig.model_fields['batching_wait_iters'].default)
 
 
@@ -2320,6 +2374,7 @@ def update_llm_args_with_extra_dict(
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
+        "attention_dp_config": AttentionDpConfig,
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
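Putting the new LLM-API surface together, usage would presumably look like the following sketch (the model path and parallel sizes are placeholders, not part of this commit; the keyword names mirror the diff above):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AttentionDpConfig

adp_config = AttentionDpConfig(
    enable_balance=True,     # turn on ADP schedule balancing
    timeout_iters=50,        # iterations a rank may defer its own prefill
    batching_wait_iters=10,  # extra iterations to accumulate context requests
)

llm = LLM(
    model="/path/to/model",       # placeholder checkpoint path
    backend='pytorch',
    tensor_parallel_size=4,
    enable_attention_dp=True,     # balancing only applies with attention DP on
    attention_dp_config=adp_config,
)

Because "attention_dp_config" is now registered in the field_mapping of update_llm_args_with_extra_dict, the same settings can also arrive as a plain dict through an extra-options file and will be coerced into an AttentionDpConfig.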

tests/integration/defs/test_e2e.py

Lines changed: 39 additions & 0 deletions
@@ -1800,6 +1800,45 @@ def test_ptp_quickstart_advanced_auto(llm_root, llm_venv, model_name,
         _check_mem_usage(running_log, [27.0, 0, 0, 0])
 
 
+@skip_post_blackwell
+@pytest.mark.skip_less_device_memory(80000)
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("model_name,model_path", [
+    pytest.param(
+        'DeepSeek-V3-Lite-FP8', 'DeepSeek-V3-Lite/fp8', marks=skip_pre_hopper),
+])
+def test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance(
+        llm_root, llm_venv, model_name, model_path):
+    print(f"Testing {model_name}.")
+    example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
+    with tempfile.NamedTemporaryFile(mode='w+t',
+                                     suffix=f".{model_name}.log",
+                                     dir="./",
+                                     delete=True,
+                                     delete_on_close=True) as running_log:
+        llm_venv.run_cmd([
+            str(example_root / "quickstart_advanced.py"),
+            "--model_dir",
+            f"{llm_models_root()}/{model_path}",
+            "--moe_tp_size=1",
+            "--moe_ep_size=4",
+            "--tp_size=4",
+            "--use_cuda_graph",
+            "--enable_attention_dp",
+            f"--kv_cache_fraction={_MEM_FRACTION_95}",
+            "--max_batch_size=1",
+            "--max_seq_len=3000",
+            "--disable_kv_cache_reuse",
+            "--attention_dp_enable_balance",
+            "--attention_dp_time_out_iters",
+            "10",
+            "--attention_dp_batching_wait_iters",
+            "10",
+        ],
+                         stdout=running_log)
+        _check_mem_usage(running_log, [106.3, 0, 0, 0], 8)
+
+
 @skip_post_blackwell
 @pytest.mark.skip_less_device_memory(110000)
 @pytest.mark.skip_less_device(8)

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - test_e2e.py::test_ptp_quickstart_advanced_bs1
+  - test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_lite_4gpus_adp_balance[DeepSeek-V3-Lite-FP8-DeepSeek-V3-Lite/fp8]
 - condition:
     ranges:
       system_gpu_count:

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
@@ -79,6 +79,10 @@ methods:
         annotation: Optional[tensorrt_llm.llmapi.llm_args.CudaGraphConfig]
         default: null
         status: beta
+      attention_dp_config:
+        annotation: Optional[tensorrt_llm.llmapi.llm_args.AttentionDpConfig]
+        default: null
+        status: beta
       checkpoint_loader:
         annotation: Optional[tensorrt_llm._torch.models.checkpoints.BaseCheckpointLoader]
         default: null

0 commit comments
