
Commit db8aebc

Fix possible premature generator completion; add tests
1 parent 10a88ec commit db8aebc

File tree

2 files changed: +101 -23 lines

  tests/async_engine/test_async_llm_engine.py
  vllm/engine/async_llm_engine.py


tests/async_engine/test_async_llm_engine.py

Lines changed: 88 additions & 13 deletions
@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional

 import pytest
+import pytest_asyncio
 import torch

 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput

+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear


@@ -118,33 +123,103 @@ async def test_new_requests_event():
     os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")


-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )

-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):

     async def run(prompt: str):
         sampling_params = SamplingParams(
             temperature=0,
             max_tokens=32,
         )

-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output

-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
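
The tests above share one engine through a module-scoped fixture. A minimal sketch of that pattern (hypothetical shared_resource fixture and test names, not part of vLLM; assumes a pytest-asyncio version where the asyncio marker accepts a scope argument, as the diff itself uses):

import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="module")
async def shared_resource():
    # Stand-in for an expensive resource such as an engine.
    resource = {"started": True}
    yield resource
    # Teardown runs once, after the last test in the module.
    resource["started"] = False


@pytest.mark.asyncio(scope="module")
async def test_first(shared_resource):
    assert shared_resource["started"]


@pytest.mark.asyncio(scope="module")
async def test_second(shared_resource):
    # Same fixture object and same event loop as test_first.
    assert shared_resource["started"]

Every test marked with the module scope runs on one shared event loop, so the awaited fixture is built once rather than per test; the should_do_global_cleanup_after_test override in the diff exists for the same reason.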

vllm/engine/async_llm_engine.py

Lines changed: 13 additions & 10 deletions
@@ -2,8 +2,8 @@
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)

 import torch
 from transformers import PreTrainedTokenizer
@@ -85,9 +85,8 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:

     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                               Exception]) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+        if not self._finished:
+            self._queue.put_nowait(item)

     def finish(
         self,
@@ -96,7 +95,7 @@ def finish(
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)

     @property
     def finished(self) -> bool:
@@ -106,11 +105,9 @@ async def generator(
         self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, BaseException) or \
-                        (isinstance(result, type) and \
-                        issubclass(result, BaseException)):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
@@ -119,6 +116,12 @@ async def generator(
             self._cancel(self.request_id)
             raise asyncio.CancelledError from None

+    @staticmethod
+    def _is_raisable(value: Any):
+        return isinstance(value, BaseException) or \
+            (isinstance(value, type) and \
+            issubclass(value, BaseException))
+

 class RequestTracker:
     """Synchronous abstraction for tracking requests."""
