From c07ece5ca490a90b2b19c33ab7da2d21e015d7bd Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 7 Sep 2023 13:43:45 -0700 Subject: [PATCH] Make `AsyncLLMEngine` more robust & fix batched abort (#969) Signed-off-by: Antoni Baum Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com> --- tests/async_engine/api_server_async_engine.py | 51 +++++ tests/async_engine/test_api_server.py | 86 ++++++++ tests/async_engine/test_request_tracker.py | 54 +++++ vllm/core/scheduler.py | 5 +- vllm/engine/async_llm_engine.py | 202 +++++++++++++----- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/openai/api_server.py | 1 + 7 files changed, 345 insertions(+), 55 deletions(-) create mode 100644 tests/async_engine/api_server_async_engine.py create mode 100644 tests/async_engine/test_api_server.py create mode 100644 tests/async_engine/test_request_tracker.py diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py new file mode 100644 index 00000000..e84916da --- /dev/null +++ b/tests/async_engine/api_server_async_engine.py @@ -0,0 +1,51 @@ +"""vllm.entrypoints.api_server with some extra logging for testing.""" +import argparse +from typing import Any, Dict + +import uvicorn +from fastapi.responses import JSONResponse, Response + +import vllm.entrypoints.api_server +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine + +app = vllm.entrypoints.api_server.app + + +class AsyncLLMEngineWithStats(AsyncLLMEngine): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_aborts = 0 + + async def abort(self, request_id: str) -> None: + await super().abort(request_id) + self._num_aborts += 1 + + def testing_stats(self) -> Dict[str, Any]: + return {"num_aborted_requests": self._num_aborts} + + +@app.get("/stats") +def stats() -> Response: + """Get the statistics of the engine.""" + return JSONResponse(engine.testing_stats()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngineWithStats.from_engine_args(engine_args, + start_engine_loop=False) + vllm.entrypoints.api_server.engine = engine + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py new file mode 100644 index 00000000..dee62b9a --- /dev/null +++ b/tests/async_engine/test_api_server.py @@ -0,0 +1,86 @@ +import subprocess +import sys +import time +from multiprocessing import Pool +from pathlib import Path + +import pytest +import requests + + +def _query_server(prompt: str) -> dict: + response = requests.post("http://localhost:8000/generate", + json={ + "prompt": prompt, + "max_tokens": 100, + "temperature": 0, + "ignore_eos": True + }) + response.raise_for_status() + return response.json() + + +@pytest.fixture +def api_server(): + script_path = Path(__file__).parent.joinpath( + "api_server_async_engine.py").absolute() + uvicorn_process = subprocess.Popen([ + sys.executable, "-u", + str(script_path), "--model", "facebook/opt-125m" + ]) + yield + uvicorn_process.terminate() + + +def test_api_server(api_server): + """ + Run the API server 
and test it. + + We run both the server and requests in separate processes. + + We test that the server can handle incoming requests, including + multiple requests at the same time, and that it can handle requests + being cancelled without crashing. + """ + with Pool(32) as pool: + # Wait until the server is ready + prompts = ["Hello world"] * 1 + result = None + while not result: + try: + for result in pool.map(_query_server, prompts): + break + except: + time.sleep(1) + + # Actual tests start here + # Try with 1 prompt + for result in pool.map(_query_server, prompts): + assert result + + num_aborted_requests = requests.get( + "http://localhost:8000/stats").json()["num_aborted_requests"] + assert num_aborted_requests == 0 + + # Try with 100 prompts + prompts = ["Hello world"] * 100 + for result in pool.map(_query_server, prompts): + assert result + + # Cancel requests + pool.map_async(_query_server, prompts) + time.sleep(0.01) + pool.terminate() + pool.join() + + # check cancellation stats + num_aborted_requests = requests.get( + "http://localhost:8000/stats").json()["num_aborted_requests"] + assert num_aborted_requests > 0 + + # check that server still runs after cancellations + with Pool(32) as pool: + # Try with 100 prompts + prompts = ["Hello world"] * 100 + for result in pool.map(_query_server, prompts): + assert result diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py new file mode 100644 index 00000000..3666e6c7 --- /dev/null +++ b/tests/async_engine/test_request_tracker.py @@ -0,0 +1,54 @@ +import pytest + +from vllm.engine.async_llm_engine import RequestTracker +from vllm.outputs import RequestOutput + + +def test_request_tracker(): + tracker = RequestTracker() + stream_1 = tracker.add_request("1") + new, finished = tracker.get_new_and_finished_requests() + assert len(new) == 1 + assert new[0]["request_id"] == "1" + assert not finished + assert not stream_1.finished + + stream_2 = tracker.add_request("2") + stream_3 = tracker.add_request("3") + new, finished = tracker.get_new_and_finished_requests() + assert len(new) == 2 + assert new[0]["request_id"] == "2" + assert new[1]["request_id"] == "3" + assert not finished + assert not stream_2.finished + assert not stream_3.finished + + # request_ids must be unique + with pytest.raises(KeyError): + tracker.add_request("1") + + tracker.abort_request("1") + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert "1" in finished + assert not new + assert stream_1.finished + + stream_4 = tracker.add_request("4") + tracker.abort_request("4") + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert "4" in finished + assert not new + assert stream_4.finished + + stream_5 = tracker.add_request("5") + tracker.process_request_output( + RequestOutput("2", "output", [], [], finished=True)) + new, finished = tracker.get_new_and_finished_requests() + assert len(finished) == 1 + assert "2" in finished + assert len(new) == 1 + assert new[0]["request_id"] == "5" + assert stream_2.finished + assert not stream_5.finished diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2696cf54..55ce5aa6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -92,7 +92,10 @@ class Scheduler: request_id = (request_id, ) request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: - for seq_group in state_queue: + # We need to reverse the list as we are removing elements + # from it as we 
iterate over it. If we don't do it, + # indices will get messed up and we will skip over elements. + for seq_group in reversed(state_queue): if seq_group.request_id in request_ids: # Remove the sequence group from the state queue. state_queue.remove(seq_group) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a706650e..fc4b4b09 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,7 +1,7 @@ import asyncio import time from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -14,6 +14,28 @@ from vllm.sampling_params import SamplingParams logger = init_logger(__name__) +class AsyncEngineDeadError(RuntimeError): + pass + + +def _raise_exception_on_finish(task: asyncio.Task, + request_tracker: "RequestTracker") -> None: + msg = ("Task finished unexpectedly. This should never happen! " + "Please open an issue on Github.") + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise AsyncEngineDeadError( + msg + " See stack trace above for the actual cause.") from exc + raise AsyncEngineDeadError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + class AsyncStream: """A stream of RequestOutputs for a request that can be iterated over asynchronously.""" @@ -43,15 +65,90 @@ class AsyncStream: result = await self._queue.get() if result is StopIteration: raise StopAsyncIteration + elif isinstance(result, Exception): + raise result return result -def _raise_exception_on_finish(task: asyncio.Task) -> None: - try: - task.result() - except Exception as e: - raise RuntimeError("Task finished unexpectedly.") from e - raise RuntimeError("Task finished unexpectedly.") +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + + def __contains__(self, item): + return item in self._request_streams + + def propagate_exception(self, exc: Exception) -> None: + """Propagate an exception to all request streams.""" + for stream in self._request_streams.values(): + stream.put(exc) + + def process_request_output(self, + request_output: RequestOutput, + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + + self._request_streams[request_id].put(request_output) + if request_output.finished: + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + + def add_request(self, request_id: str, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = AsyncStream(request_id) + self._new_requests.put_nowait((stream, { + "request_id": request_id, + **engine_add_request_kwargs + })) + return stream + + def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + 
self._finished_requests.put_nowait(request_id) + + if request_id not in self._request_streams or self._request_streams[ + request_id].finished: + # The request has already finished or been aborted. + return + + self._request_streams[request_id].finish() + + def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[dict] = [] + finished_requests: Set[str] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + self._request_streams.pop(request_id, None) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + if stream.request_id in finished_requests: + # The request has already been aborted. + stream.finish() + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + return new_requests, finished_requests class _AsyncLLMEngine(LLMEngine): @@ -150,16 +247,15 @@ class AsyncLLMEngine: self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) - # Request id -> stream. - self.request_streams: Dict[str, AsyncStream] = {} - self.finished_requests: asyncio.Queue[str] = asyncio.Queue() + self.request_tracker: RequestTracker = RequestTracker() self.background_loop = None if start_engine_loop: self.start_background_loop() @property def is_running(self) -> bool: - return self.background_loop is not None + return (self.background_loop is not None + and not self.background_loop.done()) def start_background_loop(self) -> None: """Start the background loop.""" @@ -167,7 +263,9 @@ class AsyncLLMEngine: raise RuntimeError("Background loop is already running.") self.background_loop = asyncio.get_event_loop().create_task( self.run_engine_loop()) - self.background_loop.add_done_callback(_raise_exception_on_finish) + self.background_loop.add_done_callback( + partial(_raise_exception_on_finish, + request_tracker=self.request_tracker)) def _init_engine(self, *args, **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: @@ -181,6 +279,21 @@ class AsyncLLMEngine: async def engine_step(self): """Kick the engine to process the waiting requests.""" + + new_requests, finished_requests = ( + self.request_tracker.get_new_and_finished_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + # TODO: Maybe add add_request_batch to reduce Ray overhead + if self.engine_use_ray: + await self.engine.add_request.remote(**new_request) + else: + self.engine.add_request(**new_request) + + if finished_requests: + await self._engine_abort(finished_requests) + if self.engine_use_ray: request_outputs = await self.engine.step.remote() else: @@ -188,20 +301,8 @@ class AsyncLLMEngine: # Put the outputs into the corresponding streams. 
for request_output in request_outputs: - request_id = request_output.request_id - self.request_streams[request_id].put(request_output) - if request_output.finished: - if self.log_requests: - logger.info(f"Finished request {request_id}.") - self.request_streams[request_id].finish() - self.finished_requests.put_nowait(request_id) - - finished_request = set() - while not self.finished_requests.empty(): - finished_request.add(self.finished_requests.get_nowait()) - await self._engine_abort(finished_request) - for request_id in finished_request: - del self.request_streams[request_id] + self.request_tracker.process_request_output( + request_output, verbose=self.log_requests) async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -228,25 +329,19 @@ class AsyncLLMEngine: f"sampling params: {sampling_params}, " f"prompt token ids: {prompt_token_ids}.") - if request_id in self.request_streams: - raise KeyError(f"Request {request_id} already exists.") - stream = AsyncStream(request_id) - self.request_streams[request_id] = stream + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") - # Add the request into the vLLM engine's waiting queue. - if self.engine_use_ray: - await self.engine.add_request.remote( - request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) - else: - self.engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = self.request_tracker.add_request( + request_id, + prompt=prompt, + sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) return stream @@ -300,6 +395,13 @@ class AsyncLLMEngine: Args: request_id: The unique id of the request. """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + return self._abort(request_id) def _abort(self, request_id: str) -> None: @@ -311,16 +413,8 @@ class AsyncLLMEngine: Args: request_id: The unique id of the request. """ - if request_id not in self.request_streams or self.request_streams[ - request_id].finished: - # The request has already finished or been aborted. - return - - if self.log_requests: - logger.info(f"Aborted request {request_id}.") - - self.request_streams[request_id].finish() - self.finished_requests.put_nowait(request_id) + self.request_tracker.abort_request(request_id, + verbose=self.log_requests) async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f430c0ff..b73642b2 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -14,6 +14,7 @@ from vllm.utils import random_uuid TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. 
app = FastAPI() +engine = None @app.post("/generate") diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 16296600..22170a05 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -44,6 +44,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None app = fastapi.FastAPI() +engine = None def create_error_response(status_code: HTTPStatus,
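
The scheduler hunk iterates over each state queue in reverse because abort_seq_group removes entries from the very list it is walking; a forward scan would shift later elements into already-visited slots and silently skip them. Below is a minimal, self-contained sketch of that failure mode and of the reversed() fix; SeqGroup here is a hypothetical stand-in for the real vLLM sequence group, not the actual class.

# Why removing from a list while iterating forward skips elements, and why
# the reversed() iteration used in Scheduler.abort_seq_group is safe.
# SeqGroup is a hypothetical stand-in, not the real vLLM class.
from dataclasses import dataclass
from typing import List, Set


@dataclass
class SeqGroup:
    request_id: str


def abort_forward(queue: List[SeqGroup], request_ids: Set[str]) -> None:
    # BUG: each removal shifts later elements one slot left, so the
    # iterator skips the element that slides into the freed position.
    for seq_group in queue:
        if seq_group.request_id in request_ids:
            queue.remove(seq_group)


def abort_reversed(queue: List[SeqGroup], request_ids: Set[str]) -> None:
    # Walking from the end means a removal only affects indices the
    # iterator has already passed.
    for seq_group in reversed(queue):
        if seq_group.request_id in request_ids:
            queue.remove(seq_group)


queue = [SeqGroup("a"), SeqGroup("b"), SeqGroup("c")]
abort_forward(queue, {"a", "b", "c"})
print([g.request_id for g in queue])  # ['b'] -- "b" was skipped

queue = [SeqGroup("a"), SeqGroup("b"), SeqGroup("c")]
abort_reversed(queue, {"a", "b", "c"})
print([g.request_id for g in queue])  # [] -- everything aborted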
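
The new _raise_exception_on_finish done callback, together with RequestTracker.propagate_exception, is what keeps callers from hanging forever when the background loop crashes: the loop's exception is surfaced as AsyncEngineDeadError on every open stream, and generate/abort refuse to run once is_running reports the task as done. The sketch below shows only the underlying done-callback pattern under simplified assumptions; EngineDead, on_task_done, and the plain asyncio.Queue consumers are illustrative names, not vLLM APIs.

# Done-callback pattern: a crashed background task pushes its failure to
# every waiting consumer instead of dying silently.
import asyncio
from functools import partial
from typing import List


class EngineDead(RuntimeError):
    pass


def on_task_done(task: asyncio.Task, waiters: List[asyncio.Queue]) -> None:
    try:
        task.result()  # re-raises the task's exception, if there was one
    except asyncio.CancelledError:
        return  # normal shutdown, nothing to propagate
    except Exception as exc:
        for queue in waiters:
            queue.put_nowait(EngineDead(f"background loop died: {exc!r}"))


async def main() -> None:
    waiters: List[asyncio.Queue] = [asyncio.Queue()]

    async def background_loop() -> None:
        await asyncio.sleep(0.01)
        raise ValueError("boom")

    task = asyncio.create_task(background_loop())
    task.add_done_callback(partial(on_task_done, waiters=waiters))

    # The consumer sees the failure as an item in its queue rather than
    # waiting forever on a stream that will never be fed again.
    item = await waiters[0].get()
    print(f"consumer observed: {item!r}")


asyncio.run(main())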
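
The reworked engine_step now drains the tracker exactly once per loop iteration: queued adds are handed to the engine's waiting queue first, queued aborts are collected into a set and sent in a single _engine_abort call, and the step's outputs are routed back to per-request streams. A rough sketch of that handshake follows, with a hypothetical ToyEngine standing in for the real LLMEngine and bare asyncio queues standing in for RequestTracker and AsyncStream.

# Simplified version of the engine_step() handshake: drain queued adds and
# aborts once per iteration, step the engine, route outputs to per-request
# streams. ToyEngine is a hypothetical stand-in for the real LLMEngine.
import asyncio
from typing import Dict, List, Set, Tuple


class ToyEngine:
    """Echoes each queued prompt exactly once."""

    def __init__(self) -> None:
        self.waiting: Dict[str, str] = {}

    def add_request(self, request_id: str, prompt: str) -> None:
        self.waiting[request_id] = prompt

    def abort_request(self, request_ids: Set[str]) -> None:
        for request_id in request_ids:
            self.waiting.pop(request_id, None)

    def step(self) -> List[Tuple[str, str]]:
        outputs = [(rid, f"echo: {prompt}")
                   for rid, prompt in self.waiting.items()]
        self.waiting.clear()
        return outputs


async def main() -> None:
    new_requests: asyncio.Queue = asyncio.Queue()
    finished_requests: asyncio.Queue = asyncio.Queue()
    streams: Dict[str, asyncio.Queue] = {}
    engine = ToyEngine()

    async def engine_loop() -> None:
        while True:
            # New requests are added first, then aborts are applied as one
            # batch, mirroring the order in AsyncLLMEngine.engine_step().
            while not new_requests.empty():
                request_id, prompt = new_requests.get_nowait()
                streams[request_id] = asyncio.Queue()
                engine.add_request(request_id, prompt)
            aborted: Set[str] = set()
            while not finished_requests.empty():
                aborted.add(finished_requests.get_nowait())
            if aborted:
                engine.abort_request(aborted)
            for request_id, output in engine.step():
                if request_id in streams:
                    streams[request_id].put_nowait(output)
            await asyncio.sleep(0.01)

    loop_task = asyncio.create_task(engine_loop())
    # Callers only ever touch the queues; the loop alone owns the engine.
    new_requests.put_nowait(("req-1", "Hello world"))
    await asyncio.sleep(0.05)
    print(await streams["req-1"].get())  # echo: Hello world

    loop_task.cancel()
    try:
        await loop_task
    except asyncio.CancelledError:
        pass


asyncio.run(main())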