[Bugfix][CI/Build] Fix test and improve code for merge_async_iterators
(#5096)
This commit is contained in:
parent
ae495c74ea
commit
eecd864388
@ -1,41 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from typing import AsyncIterator, Tuple
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from vllm.utils import merge_async_iterators
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_merge_async_iterators():
|
|
||||||
|
|
||||||
async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
yield f"item from iterator {idx}"
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
iterators = [mock_async_iterator(i) for i in range(3)]
|
|
||||||
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
|
|
||||||
*iterators)
|
|
||||||
|
|
||||||
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
|
|
||||||
async for idx, output in generator:
|
|
||||||
print(f"idx: {idx}, output: {output}")
|
|
||||||
|
|
||||||
task = asyncio.create_task(stream_output(merged_iterator))
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
task.cancel()
|
|
||||||
with pytest.raises(asyncio.CancelledError):
|
|
||||||
await task
|
|
||||||
|
|
||||||
for iterator in iterators:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(anext(iterator), 1)
|
|
||||||
except StopAsyncIteration:
|
|
||||||
# All iterators should be cancelled and print this message.
|
|
||||||
print("Iterator was cancelled normally")
|
|
||||||
except (Exception, asyncio.CancelledError) as e:
|
|
||||||
raise AssertionError() from e
|
|
@ -1,9 +1,64 @@
|
|||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
|
||||||
|
Tuple, TypeVar)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.utils import deprecate_kwargs
|
from vllm.utils import deprecate_kwargs, merge_async_iterators
|
||||||
|
|
||||||
from .utils import error_on_warning
|
from .utils import error_on_warning
|
||||||
|
|
||||||
|
if sys.version_info < (3, 10):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
|
||||||
|
_AwaitableT_co = TypeVar("_AwaitableT_co",
|
||||||
|
bound=Awaitable[Any],
|
||||||
|
covariant=True)
|
||||||
|
|
||||||
|
class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
|
||||||
|
|
||||||
|
def __anext__(self) -> _AwaitableT_co:
|
||||||
|
...
|
||||||
|
|
||||||
|
def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
|
||||||
|
return i.__anext__()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_merge_async_iterators():
|
||||||
|
|
||||||
|
async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
yield f"item from iterator {idx}"
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
iterators = [mock_async_iterator(i) for i in range(3)]
|
||||||
|
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
|
||||||
|
*iterators)
|
||||||
|
|
||||||
|
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
|
||||||
|
async for idx, output in generator:
|
||||||
|
print(f"idx: {idx}, output: {output}")
|
||||||
|
|
||||||
|
task = asyncio.create_task(stream_output(merged_iterator))
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
task.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
for iterator in iterators:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(anext(iterator), 1)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
# All iterators should be cancelled and print this message.
|
||||||
|
print("Iterator was cancelled normally")
|
||||||
|
except (Exception, asyncio.CancelledError) as e:
|
||||||
|
raise AssertionError() from e
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_kwargs_always():
|
def test_deprecate_kwargs_always():
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import gc
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
@ -234,8 +235,10 @@ def merge_async_iterators(
|
|||||||
yield item
|
yield item
|
||||||
except (Exception, asyncio.CancelledError) as e:
|
except (Exception, asyncio.CancelledError) as e:
|
||||||
for task in _tasks:
|
for task in _tasks:
|
||||||
# NOTE: Pass the error msg in cancel()
|
if sys.version_info >= (3, 9):
|
||||||
# when only Python 3.9+ is supported.
|
# msg parameter only supported in Python 3.9+
|
||||||
|
task.cancel(e)
|
||||||
|
else:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
raise e
|
raise e
|
||||||
await asyncio.gather(*_tasks)
|
await asyncio.gather(*_tasks)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user