diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py deleted file mode 100644 index ea453526c77f..000000000000 --- a/tests/async_engine/test_merge_async_iterators.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -from typing import AsyncIterator, Tuple - -import pytest - -from vllm.utils import merge_async_iterators - - -@pytest.mark.asyncio -async def test_merge_async_iterators(): - - async def mock_async_iterator(idx: int) -> AsyncIterator[str]: - try: - while True: - yield f"item from iterator {idx}" - await asyncio.sleep(0.1) - except asyncio.CancelledError: - pass - - iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( - *iterators) - - async def stream_output(generator: AsyncIterator[Tuple[int, str]]): - async for idx, output in generator: - print(f"idx: {idx}, output: {output}") - - task = asyncio.create_task(stream_output(merged_iterator)) - await asyncio.sleep(0.5) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - for iterator in iterators: - try: - await asyncio.wait_for(anext(iterator), 1) - except StopAsyncIteration: - # All iterators should be cancelled and print this message. - print("Iterator was cancelled normally") - except (Exception, asyncio.CancelledError) as e: - raise AssertionError() from e diff --git a/tests/test_utils.py b/tests/test_utils.py index 54dc5c6f5bfb..a6c3896fa43b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,64 @@ +import asyncio +import sys +from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, + Tuple, TypeVar) + import pytest -from vllm.utils import deprecate_kwargs +from vllm.utils import deprecate_kwargs, merge_async_iterators from .utils import error_on_warning +if sys.version_info < (3, 10): + if TYPE_CHECKING: + _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any]) + _AwaitableT_co = TypeVar("_AwaitableT_co", + bound=Awaitable[Any], + covariant=True) + + class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]): + + def __anext__(self) -> _AwaitableT_co: + ... + + def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT": + return i.__anext__() + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + + async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + iterators = [mock_async_iterator(i) for i in range(3)] + merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + *iterators) + + async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. + print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e + def test_deprecate_kwargs_always(): diff --git a/vllm/utils.py b/vllm/utils.py index 85e045cb3b76..26140e15636a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import os import socket import subprocess +import sys import tempfile import threading import uuid @@ -234,9 +235,11 @@ async def consumer(): yield item except (Exception, asyncio.CancelledError) as e: for task in _tasks: - # NOTE: Pass the error msg in cancel() - # when only Python 3.9+ is supported. - task.cancel() + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() raise e await asyncio.gather(*_tasks)