[BugFix] Overhaul async request cancellation (#7111)
parent: f9a5600649
commit: 9a3f49ae07

@@ -1,5 +1,5 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
from typing import Any, Dict
from typing import Any, Dict, Iterable

import uvicorn
from fastapi.responses import JSONResponse, Response
@@ -18,9 +18,10 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
        super().__init__(*args, **kwargs)
        self._num_aborts = 0

    async def abort(self, request_id: str) -> None:
        await super().abort(request_id)
        self._num_aborts += 1
    async def _engine_abort(self, request_ids: Iterable[str]):
        ids = list(request_ids)
        self._num_aborts += len(ids)
        await super()._engine_abort(ids)

    def testing_stats(self) -> Dict[str, Any]:
        return {"num_aborted_requests": self._num_aborts}

@@ -10,23 +10,23 @@ async def test_request_tracker():
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not aborted
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not aborted
    assert not stream_2.finished
    assert not stream_3.finished

@@ -36,9 +36,9 @@ async def test_request_tracker():
    assert not tracker.new_requests_event.is_set()

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    new, aborted = tracker.get_new_and_aborted_requests()
    assert len(aborted) == 1
    assert "1" in aborted
    assert not new
    assert stream_1.finished

@@ -46,9 +46,9 @@ async def test_request_tracker():
    tracker.abort_request("4")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    new, aborted = tracker.get_new_and_aborted_requests()
    assert len(aborted) == 1
    assert "4" in aborted
    assert not new
    assert stream_4.finished

@@ -57,10 +57,9 @@ async def test_request_tracker():
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], [], finished=True))
    await tracker.wait_for_new_requests()
    new, finished = tracker.get_new_and_finished_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(finished) == 1
    assert "2" in finished
    assert not aborted
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished

@@ -2,6 +2,7 @@ import asyncio
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
                    Tuple, TypeVar)

@@ -37,11 +38,11 @@ async def test_merge_async_iterators():
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            pass
            print(f"iterator {idx} cancelled")

    iterators = [mock_async_iterator(i) for i in range(3)]
    merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
        *iterators)
        *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))

    async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
        async for idx, output in generator:

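A side note on the `is_cancelled=partial(asyncio.sleep, 0, result=False)` argument in the updated test: the new signature expects an async callable returning a bool, and `asyncio.sleep(0, result=False)` is the cheapest awaitable that always resolves to False, i.e. "never cancelled". A tiny standalone check of that trick (not part of the change):

```python
import asyncio
from functools import partial

# An async "is_cancelled" predicate that always reports "not cancelled":
# asyncio.sleep(0, result=False) is awaitable and resolves to False.
never_cancelled = partial(asyncio.sleep, 0, result=False)


async def main() -> None:
    assert await never_cancelled() is False

    # The same thing spelled out explicitly:
    async def is_cancelled() -> bool:
        return False

    assert await is_cancelled() is False


asyncio.run(main())
```
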
@@ -1,7 +1,7 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                    Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
            "actual cause.") from e


STOP_ITERATION = Exception()  # Sentinel


class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously."""
    that can be iterated over asynchronously via an async generator."""

    def __init__(self, request_id: str) -> None:
    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

@@ -77,22 +81,30 @@ class AsyncStream:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopAsyncIteration())
    def finish(self, cancelled: bool = False) -> None:
        if not self._finished:
            self._finished = True
            self._queue.put_nowait(
                asyncio.CancelledError if cancelled else STOP_ITERATION)

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
    async def generator(
        self
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        try:
            while not self._finished:
                result = await self._queue.get()
                if isinstance(result, Exception):
                    if result == STOP_ITERATION:
                        return
                    raise result
        return result
                yield result
        except GeneratorExit:
            self._cancel(self.request_id)
            raise asyncio.CancelledError from None


class RequestTracker:

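The heart of the new `AsyncStream` is the `GeneratorExit` hook: when the consumer stops iterating (the handler breaks out, returns early, or the stream generator is closed), the generator receives `GeneratorExit` at its suspended `yield` and uses that to invoke the cancel callback for its request id. Below is a minimal self-contained sketch of the same pattern with toy names (`ToyStream` is not vLLM's class); the explicit `aclose()` stands in for the finalization that normally happens when the consuming task goes away:

```python
import asyncio
from typing import AsyncGenerator, Callable

_STOP = object()  # sentinel marking the normal end of the stream


class ToyStream:
    """Queue-backed stream that reports cancellation when its consumer quits."""

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self.request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(_STOP)

    async def generator(self) -> AsyncGenerator[object, None]:
        try:
            while True:
                item = await self._queue.get()
                if item is _STOP:
                    return
                yield item
        except GeneratorExit:
            # Consumer stopped iterating (break / close); report the request id.
            self._cancel(self.request_id)
            raise


async def main() -> None:
    aborted = []
    stream = ToyStream("req-1", aborted.append)
    for i in range(3):
        stream.put(i)

    gen = stream.generator()
    async for item in gen:
        if item == 1:
            break                 # consumer gives up early

    await gen.aclose()            # throws GeneratorExit into generator()
    print(aborted)                # ['req-1']


asyncio.run(main())
```
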
@@ -100,7 +112,7 @@ class RequestTracker:

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()
@@ -131,15 +143,21 @@ class RequestTracker:
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id
        finished = request_output.finished

        if finished:
            stream = self._request_streams.pop(request_id, None)
        else:
            stream = self._request_streams.get(request_id)
        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if (stream := self._request_streams.get(request_id)) is not None:
        if stream is not None:
            stream.put(request_output)
            if request_output.finished:
                if verbose:
            if finished:
                stream.finish()

        if verbose and finished:
            logger.info("Finished request %s.", request_id)
                self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
@@ -162,7 +180,8 @@ class RequestTracker:
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        abort_request = partial(self.abort_request, verbose=verbose)
        stream = AsyncStream(request_id, abort_request)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
@@ -175,36 +194,36 @@ class RequestTracker:

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
    def abort_request(self,
                      request_id: str,
                      *,
                      cancelled: bool = False,
                      verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info("Aborted request %s.", request_id)

        self._finished_requests.put_nowait(request_id)
        self._aborted_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return
        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(cancelled=cancelled)

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                stream.finish(cancelled=True)
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

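To make the renamed finished-vs-aborted bookkeeping concrete, here is a stripped-down sketch of the tracker logic with toy classes (not the real `RequestTracker`): aborts are queued for the engine loop, the per-request stream is finished immediately, and a request aborted before the loop ever saw it is finished as cancelled when the new-request queue is drained:

```python
import asyncio
from typing import Dict, List, Optional, Set, Tuple


class ToyStream:

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self.finished = False

    def finish(self, cancelled: bool = False) -> None:
        self.finished = True


class ToyTracker:
    """Aborts go on a queue for the engine loop; the stream finishes right away."""

    def __init__(self) -> None:
        self._request_streams: Dict[str, ToyStream] = {}
        self._aborted_requests: "asyncio.Queue[str]" = asyncio.Queue()
        self._new_requests: "asyncio.Queue[Tuple[ToyStream, dict]]" = asyncio.Queue()

    def add_request(self, request_id: str) -> ToyStream:
        stream = ToyStream(request_id)
        self._new_requests.put_nowait((stream, {"request_id": request_id}))
        return stream

    def abort_request(self, request_id: str, *, cancelled: bool = False) -> None:
        self._aborted_requests.put_nowait(request_id)
        stream = self._request_streams.pop(request_id, None)
        if stream is not None:
            stream.finish(cancelled=cancelled)

    def get_new_and_aborted_requests(self) -> Tuple[List[dict], Set[str]]:
        new: List[dict] = []
        aborted: Set[str] = set()
        while not self._aborted_requests.empty():
            request_id = self._aborted_requests.get_nowait()
            aborted.add(request_id)
            self._request_streams.pop(request_id, None)
        while not self._new_requests.empty():
            stream, req = self._new_requests.get_nowait()
            if stream.request_id in aborted:
                stream.finish(cancelled=True)  # aborted before the engine saw it
                continue
            self._request_streams[stream.request_id] = stream
            new.append(req)
        return new, aborted


async def main() -> None:
    tracker = ToyTracker()
    s = tracker.add_request("1")
    tracker.abort_request("1", cancelled=True)   # abort before the loop ran
    new, aborted = tracker.get_new_and_aborted_requests()
    print(new, aborted, s.finished)              # [] {'1'} True


asyncio.run(main())
```
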
@@ -556,8 +575,8 @@ class AsyncLLMEngine:

        Returns True if there are in-progress requests."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())
        new_requests, aborted_requests = (
            self._request_tracker.get_new_and_aborted_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
                verbose=self.log_requests,
            )

        if finished_requests:
            await self._engine_abort(finished_requests)
        if aborted_requests:
            await self._engine_abort(aborted_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
                raise
            await asyncio.sleep(0)

    # This method does not need to be async, but kept that way
    # for backwards compatibility.
    async def add_request(
        self,
        request_id: str,
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncStream:
    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            arrival_time=arrival_time or time.time(),
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request)

        return stream
        return stream.generator()

    async def generate(
        self,
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
        >>> # Process and return the final output
        >>> ...
        """
        async for output in self._process_request(
        async for output in await self.add_request(
            request_id,
            inputs,
            sampling_params,
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is a coroutine. It adds the
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
        >>> # Process and return the final output
        >>> ...
        """
        async for output in self._process_request(
        async for output in await self.add_request(
            request_id,
            inputs,
            pooling_params,
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
        ):
            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams."""
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
        )

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError) as e:
            self._abort(request_id)
            raise e

    async def abort(self, request_id: str) -> None:
        """Abort a request.

@@ -920,6 +907,7 @@ class AsyncLLMEngine:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            cancelled=True,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:

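For callers, the practical effect of dropping `_process_request` is that `generate()` now hands back the stream's async generator directly, and abandoning the iteration is what drives the abort path through `AsyncStream`'s cancel callback, rather than every caller invoking `engine.abort()` itself. A usage sketch under those assumptions (placeholder model name, vllm installed at this revision):

```python
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.utils import random_uuid


async def main() -> None:
    # Placeholder model; any model served by AsyncLLMEngine works here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # generate() returns an async generator; iterate it directly.
    stream = engine.generate("The capital of France is",
                             SamplingParams(max_tokens=64),
                             request_id=random_uuid())
    n_chunks = 0
    async for out in stream:
        n_chunks += 1
        if n_chunks >= 3:
            break                       # client gives up on this request

    # Closing the abandoned generator is what triggers the abort path (via the
    # GeneratorExit handler added to AsyncStream above); no explicit
    # engine.abort(request_id) call is needed in the caller anymore.
    await stream.aclose()
    print("consumed", n_chunks, "partial outputs")


asyncio.run(main())
```
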
@@ -1,4 +1,4 @@
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
                    runtime_checkable)

from transformers import PreTrainedTokenizer
@@ -30,7 +30,7 @@ class AsyncEngineClient(Protocol):
    def errored(self) -> bool:
        ...

    async def generate(
    def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
@@ -38,17 +38,17 @@ class AsyncEngineClient(Protocol):
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generates outputs for a request"""

    async def encode(
    def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""

    async def abort(self, request_id: str) -> None:

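On the `async def generate` to `def generate` change in the protocol: calling an async generator function needs no `await` (the call itself just builds the generator object), so the protocol method is declared as a plain `def` whose return type is `AsyncGenerator`. A minimal sketch of why an `async def ... yield` implementation still satisfies such a protocol (toy classes, not the real `AsyncEngineClient`):

```python
import asyncio
from typing import AsyncGenerator, Protocol


class StreamClient(Protocol):
    # Plain `def` returning AsyncGenerator, mirroring the change above.
    def generate(self, prompt: str) -> AsyncGenerator[str, None]:
        ...


class EchoClient:
    # An `async def` containing `yield` is exactly a function that returns
    # AsyncGenerator[str, None], so it satisfies the protocol structurally.
    async def generate(self, prompt: str) -> AsyncGenerator[str, None]:
        for token in prompt.split():
            yield token


async def consume(client: StreamClient) -> None:
    async for tok in client.generate("hello async world"):
        print(tok)


asyncio.run(consume(EchoClient()))
```
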
@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
                        random_uuid)
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")
@@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)
    results_generator = iterate_with_cancellation(
        results_generator, is_cancelled=request.is_disconnected)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -69,12 +72,11 @@ async def generate(request: Request) -> Response:

    # Non-streaming case
    final_output = None
    try:
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt

@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
from typing import Any, AsyncGenerator, Mapping, Optional

import cloudpickle
import zmq
@@ -190,9 +190,11 @@ class AsyncEngineRPCClient:
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses."""

        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
@@ -208,18 +210,18 @@ class AsyncEngineRPCClient:
                ])

                # Stream back the results from the RPC Server.
                while True:
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        raise request_output

                    if request_output.finished:
                        break
                    yield request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            if not finished:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy"""
@@ -243,6 +245,6 @@ class AsyncEngineRPCClient:
            "f{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")

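The RPC client change guarantees an abort whenever the stream is torn down before the final output: it tracks `finished` and aborts in `finally`. A self-contained sketch of that shape with stand-in names (`fake_backend` and `abort` are not the real RPC calls):

```python
import asyncio
from typing import AsyncGenerator

aborted: list = []


async def abort(request_id: str) -> None:
    aborted.append(request_id)


async def fake_backend(request_id: str, n: int) -> AsyncGenerator[dict, None]:
    # Stand-in for the RPC socket: streams n outputs, the last one "finished".
    for i in range(n):
        await asyncio.sleep(0)
        yield {"request_id": request_id, "i": i, "finished": i == n - 1}


async def generate(request_id: str) -> AsyncGenerator[dict, None]:
    # Same shape as the client change above: remember whether the final output
    # was seen, and abort in `finally` if the consumer bailed out early.
    finished = False
    try:
        async for out in fake_backend(request_id, n=5):
            finished = out["finished"]
            yield out
    finally:
        if not finished:
            await abort(request_id)


async def main() -> None:
    async for _ in generate("req-complete"):
        pass                          # consumed to completion: no abort

    gen = generate("req-early-exit")
    async for _ in gen:
        break                         # stop after the first output
    await gen.aclose()                # closing runs the `finally` abort

    print(aborted)                    # ['req-early-exit']


asyncio.run(main())
```
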
@@ -1,3 +1,4 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -29,7 +30,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.utils import random_uuid
from vllm.utils import iterate_with_cancellation, random_uuid

logger = init_logger(__name__)

@@ -176,15 +177,17 @@ class OpenAIServingChat(OpenAIServing):
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id,
                    conversation, tokenizer)
                    request, result_generator, request_id, conversation, tokenizer)
            except ValueError as e:
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response(str(e))
@@ -422,7 +425,6 @@ class OpenAIServingChat(OpenAIServing):
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request],
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
@@ -433,12 +435,12 @@ class OpenAIServingChat(OpenAIServing):
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
            if raw_request is not None and await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.async_engine_client.abort(request_id)
                return self.create_error_response("Client disconnected")
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

@@ -1,3 +1,4 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                    Optional)
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[RequestOutput]] = []
        generators: List[AsyncGenerator[RequestOutput, None]] = []
        try:
            (
                lora_request,
@@ -144,7 +145,8 @@ class OpenAIServingCompletion(OpenAIServing):
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)
            int, RequestOutput]] = merge_async_iterators(
                *generators, is_cancelled=raw_request.is_disconnected)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when use
@@ -156,7 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
        # Streaming response
        if stream:
            return self.completion_stream_generator(request,
                                                    raw_request,
                                                    result_generator,
                                                    request_id,
                                                    created_time,
@@ -168,10 +169,6 @@ class OpenAIServingCompletion(OpenAIServing):
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.async_engine_client.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
@@ -194,6 +191,8 @@ class OpenAIServingCompletion(OpenAIServing):
                model_name,
                tokenizer,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
@@ -214,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
    async def completion_stream_generator(
        self,
        request: CompletionRequest,
        raw_request: Request,
        result_generator: AsyncIterator[Tuple[int, RequestOutput]],
        request_id: str,
        created_time: int,
@@ -230,12 +228,6 @@ class OpenAIServingCompletion(OpenAIServing):
        try:
            async for prompt_idx, res in result_generator:

                # Abort the request if the client disconnects.
                if await raw_request.is_disconnected():
                    await self.async_engine_client.abort(
                        f"{request_id}-{prompt_idx}")
                    raise StopAsyncIteration()

                for output in res.outputs:
                    i = output.index + prompt_idx * num_choices
                    # TODO(simon): optimize the performance by avoiding full

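Both the completion and embedding handlers now lean on `merge_async_iterators(..., is_cancelled=raw_request.is_disconnected)` and use the yielded index to slot results into `final_res_batch`. A small usage sketch of that pattern, assuming the updated `vllm.utils` is importable (the fake streams are placeholders):

```python
import asyncio
from functools import partial
from typing import AsyncGenerator, List, Optional

from vllm.utils import merge_async_iterators


async def fake_stream(tag: str, n: int, delay: float) -> AsyncGenerator[str, None]:
    for k in range(n):
        await asyncio.sleep(delay)
        yield f"{tag}-{k}"


async def main() -> None:
    generators = [
        fake_stream(f"prompt{i}", n=2, delay=0.05 * (i + 1)) for i in range(3)
    ]

    # Same never-cancelled trick as in the updated test file above.
    merged = merge_async_iterators(
        *generators, is_cancelled=partial(asyncio.sleep, 0, result=False))

    # The yielded index says which prompt an item belongs to, which is how the
    # handlers slot results into final_res_batch.
    final_batch: List[Optional[str]] = [None] * 3
    async for i, item in merged:
        final_batch[i] = item
    print(final_batch)  # last item seen for each prompt


asyncio.run(main())
```
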
@@ -1,6 +1,7 @@
import asyncio
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple, cast
from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast

import numpy as np
from fastapi import Request
@@ -92,7 +93,7 @@ class OpenAIServingEmbedding(OpenAIServing):
        created_time = int(time.monotonic())

        # Schedule the request and get the result generator.
        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
        try:
            (
                lora_request,
@@ -138,17 +139,14 @@ class OpenAIServingEmbedding(OpenAIServing):
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
            int, EmbeddingRequestOutput]] = merge_async_iterators(
                *generators, is_cancelled=raw_request.is_disconnected)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                if await raw_request.is_disconnected():
                    # Abort the request if the client disconnects.
                    await self.async_engine_client.abort(f"{request_id}-{i}")
                    return self.create_error_response("Client disconnected")
                final_res_batch[i] = res

            for final_res in final_res_batch:
@@ -160,6 +158,8 @@ class OpenAIServingEmbedding(OpenAIServing):
            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

vllm/utils.py

@@ -1,5 +1,6 @@
import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
@@ -11,10 +12,11 @@ import tempfile
import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                    Union, overload)

@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    return _async_wrapper


class ProducerFinished:
    pass
async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert async iterator into one that polls the provided function
    at least once per second to check for client cancellation.
    """

    # Can use anext() in python >= 3.10
    awaits = [ensure_future(iterator.__anext__())]
    while True:
        done, pending = await asyncio.wait(awaits, timeout=1)
        if await is_cancelled():
            with contextlib.suppress(BaseException):
                awaits[0].cancel()
                await iterator.aclose()
            raise asyncio.CancelledError("client cancelled")
        if done:
            try:
                item = await awaits[0]
                awaits[0] = ensure_future(iterator.__anext__())
                yield item
            except StopAsyncIteration:
                # we are done
                return


def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handle the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.

    It also polls the provided function at least once per second to check
    for client cancellation.
    """
    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
                               Exception]] = asyncio.Queue()

    producers = len(iterators)

    async def producer(i: int, iterator: AsyncIterator[T]):
    # Can use anext() in python >= 3.10
    awaits = {
        ensure_future(pair[1].__anext__()): pair
        for pair in enumerate(iterators)
    }
    try:
        async for item in iterator:
            await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        # Signal to the consumer that we've finished
        await queue.put(ProducerFinished())

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        remaining = producers
        while awaits:
            done, pending = await asyncio.wait(awaits.keys(),
                                               return_when=FIRST_COMPLETED,
                                               timeout=1)
            if await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for d in done:
                pair = awaits.pop(d)
                try:
            while remaining or not queue.empty():
                # we think there is a race condition here
                item = await queue.get()

                if isinstance(item, ProducerFinished):
                    # Signal that a producer finished- not a real item
                    remaining -= 1
                    continue

                if isinstance(item, Exception):
                    raise item
                yield item
        except (Exception, asyncio.CancelledError) as e:
            for task in _tasks:
                if sys.version_info >= (3, 9):
                    # msg parameter only supported in Python 3.9+
                    task.cancel(e)
                else:
                    task.cancel()
            raise e
        await asyncio.gather(*_tasks)

    return consumer()
                    item = await d
                    i, it = pair
                    awaits[ensure_future(it.__anext__())] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()


def get_ip() -> str:

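For reference, a usage sketch of the new `iterate_with_cancellation` helper, assuming this version of `vllm.utils` is importable; the disconnect predicate below is a stand-in for `request.is_disconnected`:

```python
import asyncio

from vllm.utils import iterate_with_cancellation


async def outputs():
    # Stand-in for the engine's result generator: one chunk every 100 ms.
    i = 0
    while True:
        await asyncio.sleep(0.1)
        yield i
        i += 1


async def main() -> None:
    # Fake "client disconnected after ~0.35 s" predicate.
    deadline = asyncio.get_running_loop().time() + 0.35

    async def is_disconnected() -> bool:
        return asyncio.get_running_loop().time() > deadline

    results = iterate_with_cancellation(outputs(), is_disconnected)
    try:
        async for chunk in results:
            print("got", chunk)
    except asyncio.CancelledError:
        # Same shape as the updated handlers: report the disconnect instead of
        # polling is_disconnected() inside the consuming loop.
        print("client disconnected")


asyncio.run(main())
```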