Commit dc9a9af

yunruis authored and litaotju committed
[None][opt] Add batch waiting when scheduling (NVIDIA#7287)
Signed-off-by: yunruis <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent 8484aa9 commit dc9a9af

File tree: 7 files changed, +135 -1 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 3 additions & 0 deletions
@@ -142,6 +142,9 @@ def __init__(
         self.pytorch_backend_config.attention_dp_time_out_iters = 50
         self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.pytorch_backend_config.batch_wait_timeout_ms = 0
+        self.pytorch_backend_config.batch_wait_timeout_iters = 0
+        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
+        self.pytorch_backend_config.max_num_tokens = 8192
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
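
(Both new knobs default to 0 here, so enable_batch_waiting stays False in the PyExecutor hunk below and the AutoDeploy shim keeps its existing scheduling behavior by default.)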

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -49,7 +49,11 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10

+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    batch_wait_timeout_iters: int = 0
+    batch_wait_max_tokens_ratio: float = 0

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
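
For reference, the new knobs can be set on the dataclass directly. A minimal sketch, assuming the PyTorchConfig import path used later in this commit; the values are illustrative, not the defaults:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Illustrative settings: wait at most 10 scheduler iterations, or stop waiting
# once the scheduled tokens reach 75% of max_num_tokens (0.75 * 8192 = 6144).
config = PyTorchConfig(
    max_num_tokens=8192,
    batch_wait_timeout_ms=0,
    batch_wait_timeout_iters=10,
    batch_wait_max_tokens_ratio=0.75,
)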

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 35 additions & 0 deletions
@@ -190,6 +190,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -198,6 +199,10 @@ def __init__(self,
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -244,6 +249,7 @@ def __init__(self,
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1397,6 +1403,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests

+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1407,6 +1434,14 @@ def _schedule(self):
             scheduler_output.context_requests,
             scheduler_output.generation_requests)

+        # if no generation requests, no need to wait, to avoid dead waiting
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
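
The core of _waiting_requests is a single stop-waiting condition: release the accumulated context (prefill) requests once either the iteration budget is exhausted or the batch already carries enough tokens. A standalone sketch of that condition, using a hypothetical helper name for illustration:

def should_stop_waiting(wait_iters: int, timeout_iters: int,
                        num_scheduled_tokens: int,
                        max_tokens_ratio: float,
                        max_num_tokens: int) -> bool:
    # Release the batch when the iteration budget is used up ...
    if wait_iters >= timeout_iters:
        return True
    # ... or when enough tokens are already scheduled to make the step worthwhile.
    return num_scheduled_tokens >= max_tokens_ratio * max_num_tokens

# With batch_wait_timeout_iters=10, batch_wait_max_tokens_ratio=0.75 and
# max_num_tokens=8192: a 7000-token batch is released immediately, a
# 2000-token batch keeps waiting, and any batch is released after 10 iterations.
assert should_stop_waiting(3, 10, 7000, 0.75, 8192)
assert not should_stop_waiting(3, 10, 2000, 0.75, 8192)
assert should_stop_waiting(10, 10, 2000, 0.75, 8192)

Note that _schedule applies this delay only when generation requests are also scheduled, so context-only workloads are never held back (the "dead waiting" guard in the comment above).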

tensorrt_llm/llmapi/llm_args.py

Lines changed: 41 additions & 1 deletion
@@ -2208,6 +2208,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")

+    batch_wait_timeout_iters: int = Field(
+        default=0,
+        description=
+        "Maximum number of iterations the scheduler will wait to accumulate new coming requests for improved GPU utilization efficiency. If greater than 0, the scheduler will delay batch processing to gather more requests up to the specified iteration limit. If 0, disables timeout-iters-based batching delays.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes.If 0 disables token-based batching delays.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")

@@ -2481,6 +2493,31 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self

+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError(
+                f"batch_wait_timeout_iters must be >= 0, got {self.batch_wait_timeout_iters}"
+            )
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                f"batch_wait_max_tokens_ratio must be in range [0, 1], got {self.batch_wait_max_tokens_ratio}"
+            )
+        return self
+
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
+        executor_config = super().get_executor_config(_hf_model_dir, tokenizer)
+        executor_config.mm_encoder_only = self.mm_encoder_only
+        return executor_config
+
     # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
     def get_pytorch_backend_config(self) -> "PyTorchConfig":
         from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2547,7 +2584,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )


 def update_llm_args_with_extra_dict(
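
Because the new TorchLlmArgs fields are forwarded into PyTorchConfig, batch waiting can be enabled straight from the LLM API, just as the accuracy test below does. A minimal sketch with a placeholder model path; the values are illustrative:

from tensorrt_llm import LLM

llm = LLM(
    "/path/to/model",                  # placeholder model path
    batch_wait_timeout_iters=10,       # wait at most 10 scheduler iterations
    batch_wait_max_tokens_ratio=0.75,  # or until 75% of max_num_tokens is scheduled
)

A negative iteration count or a ratio outside [0, 1] is rejected by the new validators above.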

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 43 additions & 0 deletions
@@ -1606,6 +1606,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
+                                                                 (10, 0.75),
+                                                                 (10, 0),
+                                                                 (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
+                                 overlap_scheduler, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 1 addition & 0 deletions
@@ -465,6 +465,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]

tests/unittest/api_stability/references/llm.yaml

Lines changed: 8 additions & 0 deletions
@@ -131,6 +131,14 @@ methods:
       annotation: float
       default: 0
       status: prototype
+    batch_wait_timeout_iters:
+      annotation: int
+      default: 0
+      status: prototype
+    batch_wait_max_tokens_ratio:
+      annotation: float
+      default: 0
+      status: prototype
     print_iter_log:
       annotation: bool
       default: False
