[BugFix] Overhaul async request cancellation (#7111)

commit 9a3f49ae07 (parent f9a5600649)
@@ -1,5 +1,5 @@
 """vllm.entrypoints.api_server with some extra logging for testing."""
-from typing import Any, Dict
+from typing import Any, Dict, Iterable

 import uvicorn
 from fastapi.responses import JSONResponse, Response
@@ -18,9 +18,10 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
         super().__init__(*args, **kwargs)
         self._num_aborts = 0

-    async def abort(self, request_id: str) -> None:
-        await super().abort(request_id)
-        self._num_aborts += 1
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        ids = list(request_ids)
+        self._num_aborts += len(ids)
+        await super()._engine_abort(ids)

     def testing_stats(self) -> Dict[str, Any]:
         return {"num_aborted_requests": self._num_aborts}
@@ -10,23 +10,23 @@ async def test_request_tracker():
     stream_1 = tracker.add_request("1")
     assert tracker.new_requests_event.is_set()
     await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
     assert not tracker.new_requests_event.is_set()
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
-    assert not finished
+    assert not aborted
     assert not stream_1.finished

     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
     assert tracker.new_requests_event.is_set()
     await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
     assert not tracker.new_requests_event.is_set()
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
-    assert not finished
+    assert not aborted
     assert not stream_2.finished
     assert not stream_3.finished

@@ -36,9 +36,9 @@ async def test_request_tracker():
     assert not tracker.new_requests_event.is_set()

     tracker.abort_request("1")
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert "1" in finished
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert len(aborted) == 1
+    assert "1" in aborted
     assert not new
     assert stream_1.finished

@@ -46,9 +46,9 @@ async def test_request_tracker():
     tracker.abort_request("4")
     assert tracker.new_requests_event.is_set()
     await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert "4" in finished
+    new, aborted = tracker.get_new_and_aborted_requests()
+    assert len(aborted) == 1
+    assert "4" in aborted
     assert not new
     assert stream_4.finished

@@ -57,10 +57,9 @@ async def test_request_tracker():
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], [], finished=True))
     await tracker.wait_for_new_requests()
-    new, finished = tracker.get_new_and_finished_requests()
+    new, aborted = tracker.get_new_and_aborted_requests()
     assert not tracker.new_requests_event.is_set()
-    assert len(finished) == 1
-    assert "2" in finished
+    assert not aborted
     assert len(new) == 1
     assert new[0]["request_id"] == "5"
     assert stream_2.finished
@@ -2,6 +2,7 @@ import asyncio
 import os
 import socket
 import sys
+from functools import partial
 from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
                     Tuple, TypeVar)

@@ -37,11 +38,11 @@ async def test_merge_async_iterators():
             yield f"item from iterator {idx}"
             await asyncio.sleep(0.1)
         except asyncio.CancelledError:
-            pass
+            print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
     merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
-        *iterators)
+        *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
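Note: the updated test passes is_cancelled=partial(asyncio.sleep, 0, result=False), a zero-argument awaitable that always evaluates to False, i.e. a "never cancelled" probe. A minimal sketch of that idiom (the never_cancelled name is illustrative, not from the diff):

import asyncio
from functools import partial

# asyncio.sleep(0, result=False) yields control once and then returns False,
# so the partial behaves like an async "has the client disconnected?" check
# that always answers no.
never_cancelled = partial(asyncio.sleep, 0, result=False)

async def main():
    assert await never_cancelled() is False

asyncio.run(main())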
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)

 from transformers import PreTrainedTokenizer
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
                     "actual cause.") from e


+STOP_ITERATION = Exception()  # Sentinel
+
+
 class AsyncStream:
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
-    that can be iterated over asynchronously."""
+    that can be iterated over asynchronously via an async generator."""

-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
+        self._cancel = cancel
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

@@ -77,22 +81,30 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)

-    def finish(self) -> None:
-        self._queue.put_nowait(StopAsyncIteration())
-        self._finished = True
+    def finish(self, cancelled: bool = False) -> None:
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)

     @property
     def finished(self) -> bool:
         return self._finished

-    def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
-        result = await self._queue.get()
-        if isinstance(result, Exception):
-            raise result
-        return result
+    async def generator(
+        self
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        try:
+            while not self._finished:
+                result = await self._queue.get()
+                if isinstance(result, Exception):
+                    if result == STOP_ITERATION:
+                        return
+                    raise result
+                yield result
+        except GeneratorExit:
+            self._cancel(self.request_id)
+            raise asyncio.CancelledError from None


 class RequestTracker:
@@ -100,7 +112,7 @@ class RequestTracker:

     def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
-        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
         self.new_requests_event = asyncio.Event()
@@ -131,15 +143,21 @@ class RequestTracker:
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
+        finished = request_output.finished

+        if finished:
+            stream = self._request_streams.pop(request_id, None)
+        else:
+            stream = self._request_streams.get(request_id)
         # Guard against a KeyError which can occur if the request was aborted
         # while the output was generated
-        if (stream := self._request_streams.get(request_id)) is not None:
+        if stream is not None:
             stream.put(request_output)
-            if request_output.finished:
-                if verbose:
-                    logger.info("Finished request %s.", request_id)
-                self.abort_request(request_id)
+            if finished:
+                stream.finish()
+
+        if verbose and finished:
+            logger.info("Finished request %s.", request_id)

     def process_exception(self,
                           request_id: str,
@@ -162,7 +180,8 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")

-        stream = AsyncStream(request_id)
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
         self._new_requests.put_nowait((stream, {
             "request_id": request_id,
             **engine_add_request_kwargs
@@ -175,36 +194,36 @@ class RequestTracker:

         return stream

-    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      cancelled: bool = False,
+                      verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
             logger.info("Aborted request %s.", request_id)

-        self._finished_requests.put_nowait(request_id)
+        self._aborted_requests.put_nowait(request_id)

-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        self._request_streams[request_id].finish()
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)

-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
         new_requests: List[Dict] = []
         finished_requests: Set[str] = set()

-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
             finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)

         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
             if stream.request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish()
+                stream.finish(cancelled=True)
                 continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
@@ -556,8 +575,8 @@ class AsyncLLMEngine:

         Returns True if there are in-progress requests."""

-        new_requests, finished_requests = (
-            self._request_tracker.get_new_and_finished_requests())
+        new_requests, aborted_requests = (
+            self._request_tracker.get_new_and_aborted_requests())

         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
                 verbose=self.log_requests,
             )

-        if finished_requests:
-            await self._engine_abort(finished_requests)
+        if aborted_requests:
+            await self._engine_abort(aborted_requests)

         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
                 raise
             await asyncio.sleep(0)

+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
     async def add_request(
         self,
         request_id: str,
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncStream:
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")

-        if arrival_time is None:
-            arrival_time = time.time()
-
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
             inputs=inputs,
             params=params,
-            arrival_time=arrival_time,
+            arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request)

-        return stream
+        return stream.generator()

     async def generate(
         self,
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 sampling_params,
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
                 request_id,
                 inputs,
                 pooling_params,
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

-    async def _process_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        *,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
-        """Common logic to process requests with SamplingParams or
-        PoolingParams."""
-        arrival_time = time.time()
-
-        stream = await self.add_request(
-            request_id,
-            inputs,
-            params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (Exception, asyncio.CancelledError) as e:
-            self._abort(request_id)
-            raise e
-
     async def abort(self, request_id: str) -> None:
         """Abort a request.

@@ -920,6 +907,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
+                                            cancelled=True,
                                             verbose=self.log_requests)

     async def get_model_config(self) -> ModelConfig:
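Note: the central change above is that AsyncStream is now consumed through generator(), whose GeneratorExit handler reports the abort via the cancel callback that RequestTracker.add_request passes in. A self-contained toy sketch of that mechanism, with illustrative names (ToyStream is not vLLM code):

import asyncio

class ToyStream:
    def __init__(self, request_id, cancel):
        self.request_id = request_id
        self._cancel = cancel              # invoked when the consumer goes away
        self._queue = asyncio.Queue()

    async def generator(self):
        try:
            while True:
                yield await self._queue.get()
        except GeneratorExit:
            # aclose() on the generator (e.g. from iterate_with_cancellation)
            # lands here; report the abort upstream and finish closing.
            self._cancel(self.request_id)
            raise

async def main():
    aborted = []
    stream = ToyStream("req-1", aborted.append)
    await stream._queue.put("token")

    gen = stream.generator()
    assert await gen.__anext__() == "token"   # consume one output
    await gen.aclose()                        # consumer walks away early
    assert aborted == ["req-1"]               # cancel callback fired

asyncio.run(main())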
@@ -1,4 +1,4 @@
-from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
+from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
                     runtime_checkable)

 from transformers import PreTrainedTokenizer
@@ -30,7 +30,7 @@ class AsyncEngineClient(Protocol):
     def errored(self) -> bool:
         ...

-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -38,17 +38,17 @@ class AsyncEngineClient(Protocol):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generates outputs for a request"""

-    async def encode(
+    def encode(
         self,
         inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""

     async def abort(self, request_id: str) -> None:
@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
+                        random_uuid)
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger("vllm.entrypoints.api_server")
@@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:

     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
+    results_generator = iterate_with_cancellation(
+        results_generator, is_cancelled=request.is_disconnected)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -69,12 +72,11 @@ async def generate(request: Request) -> Response:

     # Non-streaming case
     final_output = None
-    async for request_output in results_generator:
-        if await request.is_disconnected():
-            # Abort the request if the client disconnects.
-            await engine.abort(request_id)
-            return Response(status_code=499)
-        final_output = request_output
+    try:
+        async for request_output in results_generator:
+            final_output = request_output
+    except asyncio.CancelledError:
+        return Response(status_code=499)

     assert final_output is not None
     prompt = final_output.prompt
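Note: with the helper added to vllm/utils.py later in this diff, the API server no longer polls request.is_disconnected() inside its own loop; the wrapper raises asyncio.CancelledError, which the handler maps to HTTP 499 (client closed request). A hedged usage sketch, assuming a build that already contains this change (numbers() and is_cancelled() are made-up stand-ins):

import asyncio
from vllm.utils import iterate_with_cancellation

async def numbers():
    n = 0
    while True:
        yield n
        n += 1

async def main():
    polls = 0

    async def is_cancelled() -> bool:   # stand-in for request.is_disconnected
        nonlocal polls
        polls += 1
        return polls > 3                # "client disconnects" after three polls

    received = []
    try:
        async for item in iterate_with_cancellation(numbers(), is_cancelled):
            received.append(item)
    except asyncio.CancelledError:
        # An HTTP handler would return a 499 response at this point.
        print("client cancelled after", received)

asyncio.run(main())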
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Any, AsyncIterator, Mapping, Optional
+from typing import Any, AsyncGenerator, Mapping, Optional

 import cloudpickle
 import zmq
@@ -190,9 +190,11 @@ class AsyncEngineRPCClient:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

+        finished = False
+        try:
             with self.socket() as socket:

                 # Send RPCGenerateRequest to the RPCServer.
@@ -208,18 +210,18 @@ class AsyncEngineRPCClient:
                 ])

                 # Stream back the results from the RPC Server.
-                while True:
+                while not finished:
                     message = await socket.recv()
                     request_output = cloudpickle.loads(message)

                     if isinstance(request_output, Exception):
                         raise request_output

-                    if request_output.finished:
-                        break
-                    yield request_output
-
-                yield request_output
+                    finished = request_output.finished
+                    yield request_output
+        finally:
+            if not finished:
+                await self.abort(request_id)

     async def check_health(self) -> None:
         """Raise if unhealthy"""
@@ -243,6 +245,6 @@ class AsyncEngineRPCClient:
                 "f{health_message}")

     async def encode(self, *args,
-                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         raise NotImplementedError(
             "Embeddings not supported with multiprocessing backend")
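Note: the RPC client hunk above guarantees server-side cleanup with a finally block: if the consumer stops before a finished output arrives, the request is aborted. A toy sketch of that pattern with illustrative names (FakeOutput, recv_output and abort are stand-ins, not the vLLM RPC API):

import asyncio
from dataclasses import dataclass

@dataclass
class FakeOutput:
    text: str
    finished: bool

async def stream_outputs(recv_output, abort, request_id):
    finished = False
    try:
        while not finished:
            output = await recv_output()
            finished = output.finished
            yield output
    finally:
        # Runs on normal completion, on errors, and when the consumer closes
        # the generator early; only aborts if no final output was ever seen.
        if not finished:
            await abort(request_id)

async def main():
    aborted = []

    async def recv_output():
        return FakeOutput("partial", finished=False)

    async def abort(request_id):
        aborted.append(request_id)

    gen = stream_outputs(recv_output, abort, "req-7")
    await gen.__anext__()      # take one partial output, then walk away
    await gen.aclose()         # the finally block aborts the request
    assert aborted == ["req-7"]

asyncio.run(main())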
@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
@@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
-from vllm.utils import random_uuid
+from vllm.utils import iterate_with_cancellation, random_uuid

 logger = init_logger(__name__)

@@ -176,15 +177,17 @@ class OpenAIServingChat(OpenAIServing):
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

+        if raw_request:
+            result_generator = iterate_with_cancellation(
+                result_generator, raw_request.is_disconnected)
+
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
                 request, result_generator, request_id, conversation, tokenizer)
-        else:
-            try:
-                return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id,
-                    conversation, tokenizer)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
+        try:
+            return await self.chat_completion_full_generator(
+                request, result_generator, request_id, conversation, tokenizer)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
@@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
     async def chat_completion_full_generator(
         self,
         request: ChatCompletionRequest,
-        raw_request: Optional[Request],
         result_generator: AsyncIterator[RequestOutput],
         request_id: str,
         conversation: List[ConversationMessage],
@@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
         created_time = int(time.time())
         final_res: Optional[RequestOutput] = None

-        async for res in result_generator:
-            if raw_request is not None and await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.async_engine_client.abort(request_id)
-                return self.create_error_response("Client disconnected")
-            final_res = res
+        try:
+            async for res in result_generator:
+                final_res = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")

         assert final_res is not None
+
         choices: List[ChatCompletionResponseChoice] = []
@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())

         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[RequestOutput]] = []
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(str(e))

         result_generator: AsyncIterator[Tuple[
-            int, RequestOutput]] = merge_async_iterators(*generators)
+            int, RequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)

         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
@@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
             return self.completion_stream_generator(request,
-                                                    raw_request,
                                                     result_generator,
                                                     request_id,
                                                     created_time,
@@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res

             for i, final_res in enumerate(final_res_batch):
@@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 model_name,
                 tokenizer,
             )
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        raw_request: Request,
         result_generator: AsyncIterator[Tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             async for prompt_idx, res in result_generator:

-                # Abort the request if the client disconnects.
-                if await raw_request.is_disconnected():
-                    await self.async_engine_client.abort(
-                        f"{request_id}-{prompt_idx}")
-                    raise StopAsyncIteration()
-
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
                     # TODO(simon): optimize the performance by avoiding full
@@ -1,6 +1,7 @@
+import asyncio
 import base64
 import time
-from typing import AsyncIterator, List, Optional, Tuple, cast
+from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast

 import numpy as np
 from fastapi import Request
@@ -92,7 +93,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         created_time = int(time.monotonic())

         # Schedule the request and get the result generator.
-        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
+        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -138,17 +139,14 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))

         result_generator: AsyncIterator[Tuple[
-            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
+            int, EmbeddingRequestOutput]] = merge_async_iterators(
+                *generators, is_cancelled=raw_request.is_disconnected)

         # Non-streaming response
         final_res_batch: List[Optional[EmbeddingRequestOutput]]
         final_res_batch = [None] * len(prompts)
         try:
             async for i, res in result_generator:
-                if await raw_request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await self.async_engine_client.abort(f"{request_id}-{i}")
-                    return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res

             for final_res in final_res_batch:
@@ -160,6 +158,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             response = request_output_to_embedding_response(
                 final_res_batch_checked, request_id, created_time, model_name,
                 encoding_format)
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
vllm/utils.py
@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import contextlib
 import datetime
 import enum
 import gc
@@ -11,10 +12,11 @@ import tempfile
 import threading
 import uuid
 import warnings
+from asyncio import FIRST_COMPLETED, ensure_future
 from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
-from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union, overload)

@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     return _async_wrapper


-class ProducerFinished:
-    pass
+async def iterate_with_cancellation(
+    iterator: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[T, None]:
+    """Convert async iterator into one that polls the provided function
+    at least once per second to check for client cancellation.
+    """
+
+    # Can use anext() in python >= 3.10
+    awaits = [ensure_future(iterator.__anext__())]
+    while True:
+        done, pending = await asyncio.wait(awaits, timeout=1)
+        if await is_cancelled():
+            with contextlib.suppress(BaseException):
+                awaits[0].cancel()
+                await iterator.aclose()
+            raise asyncio.CancelledError("client cancelled")
+        if done:
+            try:
+                item = await awaits[0]
+                awaits[0] = ensure_future(iterator.__anext__())
+                yield item
+            except StopAsyncIteration:
+                # we are done
+                return


-def merge_async_iterators(
-        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
+async def merge_async_iterators(
+    *iterators: AsyncGenerator[T, None],
+    is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[Tuple[int, T], None]:
     """Merge multiple asynchronous iterators into a single iterator.

     This method handle the case where some iterators finish before others.
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
+
+    It also polls the provided function at least once per second to check
+    for client cancellation.
     """
-    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
-                               Exception]] = asyncio.Queue()
-
-    producers = len(iterators)
-
-    async def producer(i: int, iterator: AsyncIterator[T]):
-        try:
-            async for item in iterator:
-                await queue.put((i, item))
-        except Exception as e:
-            await queue.put(e)
-        # Signal to the consumer that we've finished
-        await queue.put(ProducerFinished())
-
-    _tasks = [
-        asyncio.create_task(producer(i, iterator))
-        for i, iterator in enumerate(iterators)
-    ]
-
-    async def consumer():
-        remaining = producers
-        try:
-            while remaining or not queue.empty():
-                # we think there is a race condition here
-                item = await queue.get()
-
-                if isinstance(item, ProducerFinished):
-                    # Signal that a producer finished- not a real item
-                    remaining -= 1
-                    continue
-
-                if isinstance(item, Exception):
-                    raise item
-                yield item
-        except (Exception, asyncio.CancelledError) as e:
-            for task in _tasks:
-                if sys.version_info >= (3, 9):
-                    # msg parameter only supported in Python 3.9+
-                    task.cancel(e)
-                else:
-                    task.cancel()
-            raise e
-        await asyncio.gather(*_tasks)
-
-    return consumer()
+
+    # Can use anext() in python >= 3.10
+    awaits = {
+        ensure_future(pair[1].__anext__()): pair
+        for pair in enumerate(iterators)
+    }
+    try:
+        while awaits:
+            done, pending = await asyncio.wait(awaits.keys(),
+                                               return_when=FIRST_COMPLETED,
+                                               timeout=1)
+            if await is_cancelled():
+                raise asyncio.CancelledError("client cancelled")
+            for d in done:
+                pair = awaits.pop(d)
+                try:
+                    item = await d
+                    i, it = pair
+                    awaits[ensure_future(it.__anext__())] = pair
+                    yield i, item
+                except StopAsyncIteration:
+                    pass
+    finally:
+        # Cancel any remaining iterators
+        for f, (_, it) in awaits.items():
+            with contextlib.suppress(BaseException):
+                f.cancel()
+                await it.aclose()


 def get_ip() -> str:
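Note: the rewritten merge_async_iterators drives the source generators directly with asyncio.wait instead of producer tasks and a shared queue, and its is_cancelled probe is checked at least once per second. A small usage sketch, assuming a vLLM build containing the definition above (letters() is a made-up source):

import asyncio
from functools import partial
from vllm.utils import merge_async_iterators

async def letters(tag: str):
    for ch in ("a", "b"):
        yield f"{tag}{ch}"

async def main():
    # Same "never cancelled" probe used by the updated test earlier in this diff.
    never_cancelled = partial(asyncio.sleep, 0, result=False)
    merged = merge_async_iterators(letters("x"), letters("y"),
                                   is_cancelled=never_cancelled)
    async for idx, item in merged:
        print(idx, item)   # index of the source iterator plus its item

asyncio.run(main())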