Skip to content

Commit 9a3f49a

Browse files
authored
[BugFix] Overhaul async request cancellation (#7111)
1 parent f9a5600 commit 9a3f49a

File tree

11 files changed

+226
-226
lines changed

11 files changed

+226
-226
lines changed

tests/async_engine/api_server_async_engine.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
"""vllm.entrypoints.api_server with some extra logging for testing."""
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Iterable
33

44
import uvicorn
55
from fastapi.responses import JSONResponse, Response
@@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs):
1818
super().__init__(*args, **kwargs)
1919
self._num_aborts = 0
2020

21-
async def abort(self, request_id: str) -> None:
22-
await super().abort(request_id)
23-
self._num_aborts += 1
21+
async def _engine_abort(self, request_ids: Iterable[str]):
22+
ids = list(request_ids)
23+
self._num_aborts += len(ids)
24+
await super()._engine_abort(ids)
2425

2526
def testing_stats(self) -> Dict[str, Any]:
2627
return {"num_aborted_requests": self._num_aborts}

tests/async_engine/test_request_tracker.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -10,23 +10,23 @@ async def test_request_tracker():
1010
stream_1 = tracker.add_request("1")
1111
assert tracker.new_requests_event.is_set()
1212
await tracker.wait_for_new_requests()
13-
new, finished = tracker.get_new_and_finished_requests()
13+
new, aborted = tracker.get_new_and_aborted_requests()
1414
assert not tracker.new_requests_event.is_set()
1515
assert len(new) == 1
1616
assert new[0]["request_id"] == "1"
17-
assert not finished
17+
assert not aborted
1818
assert not stream_1.finished
1919

2020
stream_2 = tracker.add_request("2")
2121
stream_3 = tracker.add_request("3")
2222
assert tracker.new_requests_event.is_set()
2323
await tracker.wait_for_new_requests()
24-
new, finished = tracker.get_new_and_finished_requests()
24+
new, aborted = tracker.get_new_and_aborted_requests()
2525
assert not tracker.new_requests_event.is_set()
2626
assert len(new) == 2
2727
assert new[0]["request_id"] == "2"
2828
assert new[1]["request_id"] == "3"
29-
assert not finished
29+
assert not aborted
3030
assert not stream_2.finished
3131
assert not stream_3.finished
3232

@@ -36,19 +36,19 @@ async def test_request_tracker():
3636
assert not tracker.new_requests_event.is_set()
3737

3838
tracker.abort_request("1")
39-
new, finished = tracker.get_new_and_finished_requests()
40-
assert len(finished) == 1
41-
assert "1" in finished
39+
new, aborted = tracker.get_new_and_aborted_requests()
40+
assert len(aborted) == 1
41+
assert "1" in aborted
4242
assert not new
4343
assert stream_1.finished
4444

4545
stream_4 = tracker.add_request("4")
4646
tracker.abort_request("4")
4747
assert tracker.new_requests_event.is_set()
4848
await tracker.wait_for_new_requests()
49-
new, finished = tracker.get_new_and_finished_requests()
50-
assert len(finished) == 1
51-
assert "4" in finished
49+
new, aborted = tracker.get_new_and_aborted_requests()
50+
assert len(aborted) == 1
51+
assert "4" in aborted
5252
assert not new
5353
assert stream_4.finished
5454

@@ -57,10 +57,9 @@ async def test_request_tracker():
5757
tracker.process_request_output(
5858
RequestOutput("2", "output", [], [], [], finished=True))
5959
await tracker.wait_for_new_requests()
60-
new, finished = tracker.get_new_and_finished_requests()
60+
new, aborted = tracker.get_new_and_aborted_requests()
6161
assert not tracker.new_requests_event.is_set()
62-
assert len(finished) == 1
63-
assert "2" in finished
62+
assert not aborted
6463
assert len(new) == 1
6564
assert new[0]["request_id"] == "5"
6665
assert stream_2.finished

tests/test_utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import socket
44
import sys
5+
from functools import partial
56
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
67
Tuple, TypeVar)
78

@@ -37,11 +38,11 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
3738
yield f"item from iterator {idx}"
3839
await asyncio.sleep(0.1)
3940
except asyncio.CancelledError:
40-
pass
41+
print(f"iterator {idx} cancelled")
4142

4243
iterators = [mock_async_iterator(i) for i in range(3)]
4344
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
44-
*iterators)
45+
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
4546

4647
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
4748
async for idx, output in generator:

0 commit comments

Comments
 (0)