Make AsyncLLMEngine more robust & fix batched abort (#969)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
Authored by Antoni Baum on 2023-09-07 13:43:45 -07:00; committed by GitHub.
parent 7a9c20c715
commit c07ece5ca4
7 changed files with 345 additions and 55 deletions

api_server_async_engine.py (new file)

@@ -0,0 +1,51 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict
import uvicorn
from fastapi.responses import JSONResponse, Response
import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
app = vllm.entrypoints.api_server.app
class AsyncLLMEngineWithStats(AsyncLLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0
async def abort(self, request_id: str) -> None:
await super().abort(request_id)
self._num_aborts += 1
def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}
@app.get("/stats")
def stats() -> Response:
"""Get the statistics of the engine."""
return JSONResponse(engine.testing_stats())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
start_engine_loop=False)
vllm.entrypoints.api_server.engine = engine
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)

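Once this test server is running, the /generate endpoint behaves like the regular vLLM API server and the extra /stats endpoint reports how many requests the engine has aborted. A minimal client sketch against it (the request payload mirrors the test below; this is illustrative, not part of the commit):

import requests

# Assumes the server above was started locally, e.g.:
#   python api_server_async_engine.py --model facebook/opt-125m
response = requests.post("http://localhost:8000/generate",
                         json={"prompt": "Hello world",
                               "max_tokens": 16,
                               "temperature": 0})
response.raise_for_status()
print(response.json())

# Testing-only endpoint added by AsyncLLMEngineWithStats.
stats = requests.get("http://localhost:8000/stats").json()
print(stats["num_aborted_requests"])  # 0 while no request has been aborted
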
test_api_server.py (new file)

@@ -0,0 +1,86 @@
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests


def _query_server(prompt: str) -> dict:
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": 100,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    response.raise_for_status()
    return response.json()


@pytest.fixture
def api_server():
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    yield
    uvicorn_process.terminate()


def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["Hello world"] * 1
        result = None
        while not result:
            try:
                for result in pool.map(_query_server, prompts):
                    break
            except:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests
        pool.map_async(_query_server, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

    # check cancellation stats
    num_aborted_requests = requests.get(
        "http://localhost:8000/stats").json()["num_aborted_requests"]
    assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

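The cancellation part of this test works by killing the client worker processes: the dropped HTTP connections lead the server to abort the corresponding in-flight requests, which is what increments num_aborted_requests on /stats. A smaller, single-request reproduction of the same mechanism (an illustrative sketch, assuming the server above is already running; the exact count depends on timing):

import time
from multiprocessing import Process

import requests


def _long_request() -> None:
    requests.post("http://localhost:8000/generate",
                  json={
                      "prompt": "Hello world",
                      "max_tokens": 100,
                      "temperature": 0,
                      "ignore_eos": True
                  })


# Start a generation in a separate process, then kill the process so the
# connection is dropped while the engine is still working on the request.
proc = Process(target=_long_request)
proc.start()
time.sleep(0.05)  # give the request time to reach the server
proc.terminate()
proc.join()

time.sleep(1)  # give the server time to notice the disconnect
stats = requests.get("http://localhost:8000/stats").json()
print(stats["num_aborted_requests"])  # expected to be >= 1
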
test_request_tracker.py (new file)

@@ -0,0 +1,54 @@
import pytest

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


def test_request_tracker():
    tracker = RequestTracker()
    stream_1 = tracker.add_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    assert not new
    assert stream_1.finished

    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request("5")
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished

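RequestTracker is the hand-off point between the user-facing coroutines and the engine's background loop: callers enqueue work with add_request / abort_request, while the loop drains both queues with get_new_and_finished_requests and pushes results back through process_request_output. A condensed sketch of that round trip using only the methods exercised above (the RequestOutput fields are placeholders, not real engine output):

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput

tracker = RequestTracker()

# Frontend side: register a request and keep its stream to await results on.
stream = tracker.add_request("req-0", prompt="Hello world")

# Background-loop side: pick up new work and any aborted request ids.
new, finished = tracker.get_new_and_finished_requests()
assert new[0]["request_id"] == "req-0" and not finished

# The loop forwards each dict in `new` to the engine's add_request(**kwargs);
# when the engine later reports the request as finished, the output is
# pushed back to the waiting stream and the stream is closed.
tracker.process_request_output(
    RequestOutput("req-0", "Hello world", [], [], finished=True))
assert stream.finished
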
vllm/core/scheduler.py

@@ -92,7 +92,10 @@ class Scheduler:
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            for seq_group in state_queue:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
                 if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)

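The new comment in the scheduler describes a classic pitfall: removing elements from a list while iterating forward over it shifts later elements down, so the iterator skips whatever slid into the slot it just visited. That matters here because a batched abort can remove several sequence groups from the same queue in one pass. A standalone illustration (not part of the commit) of the bug and of why reverse iteration is safe:

# Forward iteration skips elements: after removing 2, the list shifts left
# and the iterator jumps over 4.
nums = [1, 2, 4, 5]
for n in nums:
    if n % 2 == 0:
        nums.remove(n)
print(nums)  # [1, 4, 5] -- 4 was never visited

# Reverse iteration is safe: removing the current element only shifts
# items the reverse iterator has already visited.
nums = [1, 2, 4, 5]
for n in reversed(nums):
    if n % 2 == 0:
        nums.remove(n)
print(nums)  # [1, 5]
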
vllm/engine/async_llm_engine.py

@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -14,6 +14,28 @@ from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
 
 
+class AsyncEngineDeadError(RuntimeError):
+    pass
+
+
+def _raise_exception_on_finish(task: asyncio.Task,
+                               request_tracker: "RequestTracker") -> None:
+    msg = ("Task finished unexpectedly. This should never happen! "
+           "Please open an issue on Github.")
+    try:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            return
+        except Exception as exc:
+            raise AsyncEngineDeadError(
+                msg + " See stack trace above for the actual cause.") from exc
+        raise AsyncEngineDeadError(msg)
+    except Exception as exc:
+        request_tracker.propagate_exception(exc)
+        raise exc
+
+
 class AsyncStream:
     """A stream of RequestOutputs for a request that can be
     iterated over asynchronously."""
@@ -43,15 +65,90 @@ class AsyncStream:
         result = await self._queue.get()
         if result is StopIteration:
             raise StopAsyncIteration
+        elif isinstance(result, Exception):
+            raise result
         return result
 
 
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
-    try:
-        task.result()
-    except Exception as e:
-        raise RuntimeError("Task finished unexpectedly.") from e
-    raise RuntimeError("Task finished unexpectedly.")
+class RequestTracker:
+    """Synchronous abstraction for tracking requests."""
+
+    def __init__(self) -> None:
+        self._request_streams: Dict[str, AsyncStream] = {}
+        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
+                                                dict]] = asyncio.Queue()
+
+    def __contains__(self, item):
+        return item in self._request_streams
+
+    def propagate_exception(self, exc: Exception) -> None:
+        """Propagate an exception to all request streams."""
+        for stream in self._request_streams.values():
+            stream.put(exc)
+
+    def process_request_output(self,
+                               request_output: RequestOutput,
+                               *,
+                               verbose: bool = False) -> None:
+        """Process a request output from the engine."""
+        request_id = request_output.request_id
+
+        self._request_streams[request_id].put(request_output)
+        if request_output.finished:
+            if verbose:
+                logger.info(f"Finished request {request_id}.")
+            self.abort_request(request_id)
+
+    def add_request(self, request_id: str,
+                    **engine_add_request_kwargs) -> AsyncStream:
+        """Add a request to be sent to the engine on the next background
+        loop iteration."""
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {
+            "request_id": request_id,
+            **engine_add_request_kwargs
+        }))
+        return stream
+
+    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+        """Abort a request during next background loop iteration."""
+        if verbose:
+            logger.info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[
+                request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[dict] = []
+        finished_requests: Set[str] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        return new_requests, finished_requests
 
 
 class _AsyncLLMEngine(LLMEngine):
@@ -150,16 +247,15 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)
 
-        # Request id -> stream.
-        self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
 
     @property
     def is_running(self) -> bool:
-        return self.background_loop is not None
+        return (self.background_loop is not None
+                and not self.background_loop.done())
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
@@ -167,7 +263,9 @@ class AsyncLLMEngine:
             raise RuntimeError("Background loop is already running.")
         self.background_loop = asyncio.get_event_loop().create_task(
             self.run_engine_loop())
-        self.background_loop.add_done_callback(_raise_exception_on_finish)
+        self.background_loop.add_done_callback(
+            partial(_raise_exception_on_finish,
+                    request_tracker=self.request_tracker))
 
     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -181,6 +279,21 @@ class AsyncLLMEngine:
 
     async def engine_step(self):
         """Kick the engine to process the waiting requests."""
+        new_requests, finished_requests = (
+            self.request_tracker.get_new_and_finished_requests())
+
+        for new_request in new_requests:
+            # Add the request into the vLLM engine's waiting queue.
+            # TODO: Maybe add add_request_batch to reduce Ray overhead
+            if self.engine_use_ray:
+                await self.engine.add_request.remote(**new_request)
+            else:
+                self.engine.add_request(**new_request)
+
+        if finished_requests:
+            await self._engine_abort(finished_requests)
+
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
@@ -188,20 +301,8 @@
 
         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            request_id = request_output.request_id
-            self.request_streams[request_id].put(request_output)
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-                self.request_streams[request_id].finish()
-                self.finished_requests.put_nowait(request_id)
-
-        finished_request = set()
-        while not self.finished_requests.empty():
-            finished_request.add(self.finished_requests.get_nowait())
-        await self._engine_abort(finished_request)
-        for request_id in finished_request:
-            del self.request_streams[request_id]
+            self.request_tracker.process_request_output(
+                request_output, verbose=self.log_requests)
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
@@ -228,25 +329,19 @@
             f"sampling params: {sampling_params}, "
             f"prompt token ids: {prompt_token_ids}.")
 
-        if request_id in self.request_streams:
-            raise KeyError(f"Request {request_id} already exists.")
-        stream = AsyncStream(request_id)
-        self.request_streams[request_id] = stream
-
-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")
+
+        stream = self.request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            sampling_params=sampling_params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time)
 
         return stream
@@ -300,6 +395,13 @@
         Args:
             request_id: The unique id of the request.
         """
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")
+
         return self._abort(request_id)
 
     def _abort(self, request_id: str) -> None:
@@ -311,16 +413,8 @@
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_streams or self.request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        if self.log_requests:
-            logger.info(f"Aborted request {request_id}.")
-
-        self.request_streams[request_id].finish()
-        self.finished_requests.put_nowait(request_id)
+        self.request_tracker.abort_request(request_id,
+                                           verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""

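The robustness changes above hinge on two pieces: the background loop's done callback now forwards any crash to the RequestTracker (so pending callers receive the exception instead of hanging forever) and re-raises it as AsyncEngineDeadError, and the request-submission and abort entry points first check is_running so new work fails fast once the loop has died. A self-contained sketch of that done-callback pattern with toy names (not the vLLM classes):

import asyncio
from functools import partial


class DeadError(RuntimeError):
    pass


def on_finish(task: asyncio.Task, waiters: list) -> None:
    # Mirrors _raise_exception_on_finish: surface the background task's
    # failure to everyone still waiting on it, then re-raise loudly.
    try:
        task.result()
    except asyncio.CancelledError:
        return
    except Exception as exc:
        for fut in waiters:
            if not fut.done():
                fut.set_exception(DeadError("background loop died"))
        # asyncio logs this re-raised error; that is intentional so the
        # crash stays visible even if nobody is awaiting a result.
        raise DeadError("background loop died") from exc


async def main() -> None:
    waiters = [asyncio.get_running_loop().create_future()]

    async def background_loop() -> None:
        await asyncio.sleep(0.1)
        raise ValueError("boom")  # simulated engine failure

    task = asyncio.create_task(background_loop())
    task.add_done_callback(partial(on_finish, waiters=waiters))

    try:
        await waiters[0]  # a caller waiting for its request's output
    except DeadError as exc:
        print("caller saw:", exc)

    # is_running-style check: alive only while the task has not completed.
    print("is_running:", task is not None and not task.done())


asyncio.run(main())
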
vllm/entrypoints/api_server.py

@@ -14,6 +14,7 @@ from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
+engine = None
 
 
 @app.post("/generate")

vllm/entrypoints/openai/api_server.py

@@ -44,6 +44,7 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
 app = fastapi.FastAPI()
+engine = None
 
 
 def create_error_response(status_code: HTTPStatus,