[BugFix] Overhaul async request cancellation (#7111)

Nick Hill 2024-08-06 22:21:41 -07:00 committed by GitHub
parent f9a5600649
commit 9a3f49ae07
11 changed files with 222 additions and 222 deletions

View File

@ -1,5 +1,5 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
from typing import Any, Dict
from typing import Any, Dict, Iterable
import uvicorn
from fastapi.responses import JSONResponse, Response
@ -18,9 +18,10 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
super().__init__(*args, **kwargs)
self._num_aborts = 0
async def abort(self, request_id: str) -> None:
await super().abort(request_id)
self._num_aborts += 1
async def _engine_abort(self, request_ids: Iterable[str]):
ids = list(request_ids)
self._num_aborts += len(ids)
await super()._engine_abort(ids)
def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}

View File

@ -10,23 +10,23 @@ async def test_request_tracker():
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
assert not aborted
assert not stream_1.finished
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
assert not finished
assert not aborted
assert not stream_2.finished
assert not stream_3.finished
@ -36,9 +36,9 @@ async def test_request_tracker():
assert not tracker.new_requests_event.is_set()
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "1" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "1" in aborted
assert not new
assert stream_1.finished
@ -46,9 +46,9 @@ async def test_request_tracker():
tracker.abort_request("4")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "4" in aborted
assert not new
assert stream_4.finished
@ -57,10 +57,9 @@ async def test_request_tracker():
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(finished) == 1
assert "2" in finished
assert not aborted
assert len(new) == 1
assert new[0]["request_id"] == "5"
assert stream_2.finished

View File

@ -2,6 +2,7 @@ import asyncio
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)
@ -37,11 +38,11 @@ async def test_merge_async_iterators():
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass
print(f"iterator {idx} cancelled")
iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators)
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:

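For readers unfamiliar with the `partial(asyncio.sleep, 0, result=False)` idiom used above: it builds an awaitable zero-argument callable that always reports "not cancelled" while still yielding control to the event loop once. A minimal standalone sketch (the `never_cancelled` and `demo` names are illustrative, not part of the test):

```python
import asyncio
from functools import partial

# Awaitable factory: awaiting it yields to the event loop once and
# then evaluates to False, i.e. "the client has not cancelled".
never_cancelled = partial(asyncio.sleep, 0, result=False)

async def demo() -> None:
    assert await never_cancelled() is False

asyncio.run(demo())
```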
View File

@ -1,7 +1,7 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
"actual cause.") from e
STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
that can be iterated over asynchronously via an async generator."""
def __init__(self, request_id: str) -> None:
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
@ -77,22 +81,30 @@ class AsyncStream:
return
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopAsyncIteration())
self._finished = True
def finish(self, cancelled: bool = False) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(
asyncio.CancelledError if cancelled else STOP_ITERATION)
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
result = await self._queue.get()
if isinstance(result, Exception):
raise result
return result
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
try:
while not self._finished:
result = await self._queue.get()
if isinstance(result, Exception):
if result == STOP_ITERATION:
return
raise result
yield result
except GeneratorExit:
self._cancel(self.request_id)
raise asyncio.CancelledError from None
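The `except GeneratorExit` handler above is the core of the new cancellation path: when the consumer stops iterating the stream (for example because the generator is closed after a client disconnect), the stream calls back into the tracker to abort the request. A simplified, hypothetical stand-in for `AsyncStream` showing just that mechanism:

```python
import asyncio
from typing import AsyncGenerator, Callable, List

class TinyStream:
    """Simplified stand-in for AsyncStream (illustration only)."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item: str) -> None:
        self._queue.put_nowait(item)

    async def generator(self) -> AsyncGenerator[str, None]:
        try:
            while True:
                yield await self._queue.get()
        except GeneratorExit:
            # The consumer stopped iterating; tell the tracker to abort.
            self._cancel(self.request_id)
            raise

async def main() -> None:
    aborted: List[str] = []
    stream = TinyStream("req-0", aborted.append)
    stream.put("token-1")
    gen = stream.generator()
    print(await gen.__anext__())  # consume one item
    await gen.aclose()            # stop early -> GeneratorExit -> abort
    assert aborted == ["req-0"]

asyncio.run(main())
```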
class RequestTracker:
@ -100,7 +112,7 @@ class RequestTracker:
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()
@ -131,15 +143,21 @@ class RequestTracker:
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
finished = request_output.finished
if finished:
stream = self._request_streams.pop(request_id, None)
else:
stream = self._request_streams.get(request_id)
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
if (stream := self._request_streams.get(request_id)) is not None:
if stream is not None:
stream.put(request_output)
if request_output.finished:
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
if finished:
stream.finish()
if verbose and finished:
logger.info("Finished request %s.", request_id)
def process_exception(self,
request_id: str,
@ -162,7 +180,8 @@ class RequestTracker:
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
abort_request = partial(self.abort_request, verbose=verbose)
stream = AsyncStream(request_id, abort_request)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
@ -175,36 +194,36 @@ class RequestTracker:
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
def abort_request(self,
request_id: str,
*,
cancelled: bool = False,
verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info("Aborted request %s.", request_id)
self._finished_requests.put_nowait(request_id)
self._aborted_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[
request_id].finished:
# The request has already finished or been aborted.
return
stream = self._request_streams.pop(request_id, None)
if stream is not None:
stream.finish(cancelled=cancelled)
self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
while not self._aborted_requests.empty():
request_id = self._aborted_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
stream.finish(cancelled=True)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
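The net effect of the tracker changes is a simple handoff: the frontend enqueues new streams and aborted request ids, and the background loop drains both queues each iteration, dropping any request that was aborted before it was ever scheduled. A condensed, hypothetical sketch of that handoff (not the real RequestTracker API):

```python
import asyncio
from typing import Dict, List, Set, Tuple

class MiniTracker:
    """Condensed illustration of the new/aborted queue handoff."""

    def __init__(self) -> None:
        self._streams: Dict[str, asyncio.Queue] = {}
        self._new: asyncio.Queue = asyncio.Queue()      # (request_id, stream)
        self._aborted: asyncio.Queue = asyncio.Queue()  # request_id

    def add_request(self, request_id: str) -> asyncio.Queue:
        stream: asyncio.Queue = asyncio.Queue()
        self._new.put_nowait((request_id, stream))
        return stream

    def abort_request(self, request_id: str) -> None:
        self._aborted.put_nowait(request_id)
        self._streams.pop(request_id, None)

    def get_new_and_aborted(self) -> Tuple[List[str], Set[str]]:
        aborted: Set[str] = set()
        while not self._aborted.empty():
            aborted.add(self._aborted.get_nowait())
        new: List[str] = []
        while not self._new.empty():
            request_id, stream = self._new.get_nowait()
            if request_id in aborted:
                continue  # aborted before the engine ever saw it
            self._streams[request_id] = stream
            new.append(request_id)
        return new, aborted

tracker = MiniTracker()
tracker.add_request("a")
tracker.add_request("b")
tracker.abort_request("b")
print(tracker.get_new_and_aborted())  # (['a'], {'b'})
```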
@ -556,8 +575,8 @@ class AsyncLLMEngine:
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self._request_tracker.get_new_and_finished_requests())
new_requests, aborted_requests = (
self._request_tracker.get_new_and_aborted_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
@ -576,8 +595,8 @@ class AsyncLLMEngine:
verbose=self.log_requests,
)
if finished_requests:
await self._engine_abort(finished_requests)
if aborted_requests:
await self._engine_abort(aborted_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
@ -666,6 +685,8 @@ class AsyncLLMEngine:
raise
await asyncio.sleep(0)
# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
self,
request_id: str,
@ -675,7 +696,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@ -686,20 +707,17 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
return stream
return stream.generator()
async def generate(
self,
@ -709,7 +727,7 @@ class AsyncLLMEngine:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
@ -774,7 +792,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
async for output in await self.add_request(
request_id,
inputs,
sampling_params,
@ -791,7 +809,7 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
@ -852,7 +870,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self._process_request(
async for output in await self.add_request(
request_id,
inputs,
pooling_params,
@ -861,37 +879,6 @@ class AsyncLLMEngine:
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def _process_request(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time = time.time()
stream = await self.add_request(
request_id,
inputs,
params,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
try:
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
self._abort(request_id)
raise e
async def abort(self, request_id: str) -> None:
"""Abort a request.
@ -920,6 +907,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
cancelled=True,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:

View File

@ -1,4 +1,4 @@
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
@ -30,7 +30,7 @@ class AsyncEngineClient(Protocol):
def errored(self) -> bool:
...
async def generate(
def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
@ -38,17 +38,17 @@ class AsyncEngineClient(Protocol):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
async def encode(
def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
async def abort(self, request_id: str) -> None:

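The switch from `async def generate` to plain `def generate` in this protocol is deliberate: an `async def` that contains `yield` is an async-generator function, so calling it returns an `AsyncGenerator` synchronously, and that is what the protocol should declare. A small illustration with made-up names:

```python
import asyncio
from typing import AsyncGenerator, Protocol

class Client(Protocol):
    # Declared as a plain method returning an async generator.
    def stream(self) -> AsyncGenerator[int, None]:
        ...

class Impl:
    # An async-generator function: calling it returns an AsyncGenerator
    # without awaiting, which matches the Protocol above.
    async def stream(self) -> AsyncGenerator[int, None]:
        for i in range(3):
            yield i

async def main() -> None:
    client: Client = Impl()
    async for value in client.stream():
        print(value)

asyncio.run(main())
```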
View File

@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@ -69,12 +72,11 @@ async def generate(request: Request) -> Response:
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
from typing import Any, AsyncGenerator, Mapping, Optional
import cloudpickle
import zmq
@ -190,36 +190,38 @@ class AsyncEngineRPCClient:
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with self.socket() as socket:
finished = False
try:
with self.socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
# Stream back the results from the RPC Server.
while True:
message = await socket.recv()
request_output = cloudpickle.loads(message)
# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv()
request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception):
raise request_output
if isinstance(request_output, Exception):
raise request_output
if request_output.finished:
break
yield request_output
yield request_output
finished = request_output.finished
yield request_output
finally:
if not finished:
await self.abort(request_id)
async def check_health(self) -> None:
"""Raise if unhealthy"""
@ -243,6 +245,6 @@ class AsyncEngineRPCClient:
"f{health_message}")
async def encode(self, *args,
**kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")

View File

@ -1,3 +1,4 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.utils import random_uuid
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
@ -176,18 +177,20 @@ class OpenAIServingChat(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []

View File

@ -1,3 +1,4 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.time())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
int, RequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if stream:
return self.completion_stream_generator(request,
raw_request,
result_generator,
request_id,
created_time,
@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name,
tokenizer,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
try:
async for prompt_idx, res in result_generator:
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.async_engine_client.abort(
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full

View File

@ -1,6 +1,7 @@
import asyncio
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple, cast
from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
import numpy as np
from fastapi import Request
@ -92,7 +93,7 @@ class OpenAIServingEmbedding(OpenAIServing):
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
@ -138,17 +139,14 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for final_res in final_res_batch:
@ -160,6 +158,8 @@ class OpenAIServingEmbedding(OpenAIServing):
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

View File

@ -1,5 +1,6 @@
import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
@ -11,10 +12,11 @@ import tempfile
import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union, overload)
@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
pass
async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
"""Convert async iterator into one that polls the provided function
at least once per second to check for client cancellation.
"""
# Can use anext() in python >= 3.10
awaits = [ensure_future(iterator.__anext__())]
while True:
done, pending = await asyncio.wait(awaits, timeout=1)
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
if done:
try:
item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__())
yield item
except StopAsyncIteration:
# we are done
return
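A usage sketch for the helper above, assuming it is imported from `vllm.utils` as added here; the `produce` generator and the cancellation flag are made up for the demo:

```python
import asyncio

from vllm.utils import iterate_with_cancellation

async def produce():
    for i in range(5):
        await asyncio.sleep(0.2)
        yield i

async def main() -> None:
    cancelled = False

    async def is_cancelled() -> bool:
        return cancelled

    seen = []
    try:
        async for item in iterate_with_cancellation(produce(),
                                                    is_cancelled=is_cancelled):
            seen.append(item)
            if item == 2:
                cancelled = True  # simulate a client disconnect
    except asyncio.CancelledError:
        pass  # the wrapper raises once it observes the cancellation
    assert seen == [0, 1, 2]

asyncio.run(main())
```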
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
It also polls the provided function at least once per second to check
for client cancellation.
"""
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
remaining = producers
try:
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished- not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
except (Exception, asyncio.CancelledError) as e:
for task in _tasks:
if sys.version_info >= (3, 9):
# msg parameter only supported in Python 3.9+
task.cancel(e)
else:
task.cancel()
raise e
await asyncio.gather(*_tasks)
return consumer()
# Can use anext() in python >= 3.10
awaits = {
ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators)
}
try:
while awaits:
done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED,
timeout=1)
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[ensure_future(it.__anext__())] = pair
yield i, item
except StopAsyncIteration:
pass
finally:
# Cancel any remaining iterators
for f, (_, it) in awaits.items():
with contextlib.suppress(BaseException):
f.cancel()
await it.aclose()
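And a usage sketch for the rewritten merge_async_iterators; the `numbers` producers are made up, and the `is_cancelled` stub mirrors the one used in the updated test:

```python
import asyncio
from functools import partial

from vllm.utils import merge_async_iterators

async def numbers(tag: str, count: int):
    for i in range(count):
        await asyncio.sleep(0.05)
        yield f"{tag}-{i}"

async def main() -> None:
    merged = merge_async_iterators(
        numbers("a", 2),
        numbers("b", 3),
        # A real server would pass something like request.is_disconnected.
        is_cancelled=partial(asyncio.sleep, 0, result=False),
    )
    async for idx, item in merged:
        print(idx, item)  # index of the source iterator, then its item

asyncio.run(main())
```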
def get_ip() -> str: