Skip to content

Commit 0ffafbb

Browse files
Shunkang authored and committed
Add unittest
Signed-off-by: Shunkang <[email protected]>
1 parent fd7a311 commit 0ffafbb

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
5050
self.waiting_queue: deque[RequestQueueItem] = deque()
5151
self.canceled_req_ids = []
5252
self.enable_attention_dp = enable_attention_dp
53+
self.max_batch_size = max_batch_size
5354
self.max_beam_width = max_beam_width
5455
self.max_num_active_requests = max_num_active_requests
5556
self.is_disaggregated = is_disaggregated

tensorrt_llm/llmapi/llm_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2330,7 +2330,7 @@ def validate_attention_dp_config(self) -> 'TorchLlmArgs':
23302330
@model_validator(mode='after')
23312331
def validate_batch_wait_timeout(self) -> 'TorchLlmArgs':
    """Validate that ``batch_wait_timeout`` is non-negative.

    A value of 0 is accepted; only negative values are rejected.

    Returns:
        The instance itself, unchanged.

    Raises:
        ValueError: If ``batch_wait_timeout`` is negative.
    """
    # The check allows 0, so the message must say "or equal to" to match.
    if self.batch_wait_timeout < 0:
        raise ValueError(
            "batch_wait_timeout must be greater than or equal to 0")
    return self
23362336

tests/unittest/_torch/test_executor_request_queue.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def executor_queue(mock_dist):
4040
max_beam_width=1,
4141
max_num_active_requests=16,
4242
enable_iter_perf_stats=True,
43+
batch_wait_timeout=0.0,
4344
is_disaggregated=False)
4445

4546

@@ -52,6 +53,7 @@ def integration_queue(mock_dist):
5253
max_beam_width=2,
5354
max_num_active_requests=8,
5455
enable_iter_perf_stats=True,
56+
batch_wait_timeout=0.0,
5557
is_disaggregated=False)
5658

5759

@@ -215,6 +217,75 @@ def test_get_from_request_queue_with_timeout(executor_queue):
215217
assert elapsed < 0.2 # Should finish within timeout
216218

217219

220+
def test_get_from_request_queue_async_behavior(executor_queue):
    """Test asynchronous behavior where requests arrive over time.

    Phase 1 sets ``batch_wait_timeout`` to 0.0 and asserts that only the
    requests already queued are returned. Phase 2 sets it to 0.2 and asserts
    that requests queued 0.05s later by a background thread are returned
    together with the initial ones.
    """
    import threading

    def add_requests_after_delay(delay, num_requests):
        """Sleep for ``delay`` seconds, then enqueue ``num_requests`` items with ids starting at 10."""
        time.sleep(delay)
        for i in range(num_requests):
            executor_queue.request_queue.put(RequestQueueItem(i + 10, Mock()))

    def timed_fetch(num_delayed):
        """Start a delayed producer, fetch once, and return ``(items, elapsed)``.

        The producer thread is always joined, even when the fetch raises, so
        a failure cannot leave a live thread pushing into the queue.
        """
        producer = threading.Thread(target=add_requests_after_delay,
                                    args=(0.05, num_delayed))
        producer.start()
        try:
            # perf_counter is monotonic and high resolution, unlike the
            # adjustable wall clock returned by time.time().
            start_time = time.perf_counter()
            fetched = executor_queue._get_from_request_queue(None)
            duration = time.perf_counter() - start_time
        finally:
            producer.join()
        return fetched, duration

    # Test 1: Without batch_wait_timeout (should only get initial requests)
    executor_queue.batch_wait_timeout = 0.0

    initial_requests = 3
    for i in range(initial_requests):
        executor_queue.request_queue.put(RequestQueueItem(i, Mock()))

    # Fetch immediately - the delayed requests (ids >= 10) must not be included
    items, elapsed = timed_fetch(2)

    assert len(items) == initial_requests
    assert elapsed < 0.1
    assert all(item.id < 10 for item in items)

    # Test 2: With batch_wait_timeout (should wait and get all requests)
    executor_queue.batch_wait_timeout = 0.2

    # Drop the delayed requests left over from Test 1 before starting again
    while True:
        try:
            executor_queue.request_queue.get_nowait()
        except queue.Empty:
            break

    initial_requests = 2
    for i in range(initial_requests):
        executor_queue.request_queue.put(RequestQueueItem(i + 20, Mock()))

    # Fetch with batch_wait_timeout - should wait and return all requests
    items, elapsed = timed_fetch(3)

    assert len(items) == initial_requests + 3
    assert elapsed >= 0.05
    assert elapsed < 0.3

    initial_ids = {item.id for item in items if 20 <= item.id < 30}
    delayed_ids = {item.id for item in items if 10 <= item.id < 20}
    assert len(initial_ids) == initial_requests
    assert len(delayed_ids) == 3
287+
288+
218289
def test_get_from_waiting_queue(executor_queue):
219290
"""Test getting items from waiting queue."""
220291
# Add items to waiting queue
@@ -371,6 +442,7 @@ def attention_dp_queue(mock_dist_attention_dp):
371442
max_beam_width=2,
372443
max_num_active_requests=8,
373444
enable_iter_perf_stats=True,
445+
batch_wait_timeout=0.0,
374446
is_disaggregated=False)
375447
# Initialize all_ranks_num_active_requests
376448
return queue

0 commit comments

Comments
 (0)