133 lines
3.8 KiB
Python
133 lines
3.8 KiB
Python
import asyncio
|
|
import os
|
|
import socket
|
|
import sys
|
|
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
|
|
Tuple, TypeVar)
|
|
|
|
import pytest
|
|
|
|
from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators
|
|
|
|
from .utils import error_on_warning
|
|
|
|
# ``anext`` is only a builtin from Python 3.10 onwards; older interpreters
# get a minimal stand-in so the tests in this module can call it uniformly.
if sys.version_info < (3, 10):
    if TYPE_CHECKING:
        # Typing-only helpers. They never exist at runtime, which is why the
        # annotations on ``anext`` below are written as strings.
        _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
        _AwaitableT_co = TypeVar(
            "_AwaitableT_co",
            bound=Awaitable[Any],
            covariant=True,
        )

        class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
            """Structural type: any object whose ``__anext__`` returns an
            awaitable."""

            def __anext__(self) -> _AwaitableT_co:
                ...

    def anext(
        async_iterator: "_SupportsSynchronousAnext[_AwaitableT]", /
    ) -> "_AwaitableT":
        """Return the awaitable produced by ``async_iterator.__anext__()``.

        Single-argument fallback for the ``anext`` builtin; no default value
        is supported.
        """
        return async_iterator.__anext__()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_merge_async_iterators():
    """Cancelling the consumer of a merged stream must finish every source.

    Three endless async generators are merged with ``merge_async_iterators``
    and consumed by a background task. After the task is cancelled, each
    source generator is probed with ``anext`` and is required to be exhausted
    (``StopAsyncIteration``); any other outcome fails the test.
    """

    async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
        try:
            while True:
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            # Swallow the cancellation so the generator simply returns; a
            # later ``anext`` on it then raises ``StopAsyncIteration``.
            pass

    iterators = [mock_async_iterator(i) for i in range(3)]
    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
        *iterators)

    async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    # Let the consumer run for a while, then cancel it mid-stream.
    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            leftover = await asyncio.wait_for(anext(iterator), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
        else:
            # Getting an item back means this source was never shut down;
            # previously this case passed silently.
            raise AssertionError(
                f"iterator was not cancelled; it still produced {leftover!r}")
|
|
|
|
|
|
def test_deprecate_kwargs_always():
    """With ``is_deprecated=True``, passing ``old_arg`` emits a
    ``DeprecationWarning`` naming it, while ``new_arg`` stays silent."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    # The deprecated keyword must trigger a warning that mentions its name.
    expect_deprecation = pytest.warns(DeprecationWarning, match="'old_arg'")
    with expect_deprecation:
        dummy(old_arg=1)

    # NOTE(review): error_on_warning is a project helper; presumably it fails
    # the block if any warning is raised inside it.
    expect_silence = error_on_warning()
    with expect_silence:
        dummy(new_arg=1)
|
|
|
|
|
|
def test_deprecate_kwargs_never():
    """With ``is_deprecated=False``, neither keyword may produce a warning."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    # Same check for both keywords, each inside its own error_on_warning()
    # context (project helper; presumably escalates warnings to failures).
    for call_kwargs in ({"old_arg": 1}, {"new_arg": 1}):
        with error_on_warning():
            dummy(**call_kwargs)
|
|
|
|
|
|
def test_deprecate_kwargs_dynamic():
    """A callable ``is_deprecated`` is consulted at call time.

    The same decorated function must warn for ``old_arg`` while the callable
    returns ``True`` and stay silent once it returns ``False``.
    """
    # Mutable holder read by the lambda, so the flag can be flipped after
    # the function has been decorated.
    state = {"deprecated": True}

    @deprecate_kwargs("old_arg", is_deprecated=lambda: state["deprecated"])
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    # Phase 1: deprecation active.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning():
        dummy(new_arg=1)

    # Phase 2: deprecation switched off; no keyword may warn any more.
    state["deprecated"] = False

    for call_kwargs in ({"old_arg": 1}, {"new_arg": 1}):
        with error_on_warning():
            dummy(**call_kwargs)
|
|
|
|
|
|
def test_deprecate_kwargs_additional_message():
    """The ``additional_message`` text must appear in the emitted warning."""
    extra_text = "abcd"

    @deprecate_kwargs("old_arg",
                      is_deprecated=True,
                      additional_message=extra_text)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    # ``match`` is a regex search over the warning message.
    expect_extra_text = pytest.warns(DeprecationWarning, match=extra_text)
    with expect_extra_text:
        dummy(old_arg=1)
|
|
|
|
|
|
def test_get_open_port():
    """``get_open_port`` must hand out bindable ports while ``VLLM_PORT`` is set.

    Three sockets are bound one after another, each kept open while the next
    port is requested, so every call has to return a port that can still be
    bound. The environment variable is restored afterwards even if a bind
    fails, so the setting cannot leak into other tests.
    """
    previous_port = os.environ.get("VLLM_PORT")
    os.environ["VLLM_PORT"] = "5678"
    try:
        # make sure we can get multiple ports, even if the env var is set
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
            s1.bind(("localhost", get_open_port()))
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
                s2.bind(("localhost", get_open_port()))
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
                    s3.bind(("localhost", get_open_port()))
    finally:
        # Put the environment back exactly as it was found.
        if previous_port is None:
            os.environ.pop("VLLM_PORT", None)
        else:
            os.environ["VLLM_PORT"] = previous_port
|