Make AsyncLLMEngine more robust & fix batched abort (#969)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>

parent 7a9c20c715
commit c07ece5ca4
tests/async_engine/api_server_async_engine.py (new file, 51 lines)
@@ -0,0 +1,51 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict

import uvicorn
from fastapi.responses import JSONResponse, Response

import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

app = vllm.entrypoints.api_server.app


class AsyncLLMEngineWithStats(AsyncLLMEngine):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_aborts = 0

    async def abort(self, request_id: str) -> None:
        await super().abort(request_id)
        self._num_aborts += 1

    def testing_stats(self) -> Dict[str, Any]:
        return {"num_aborted_requests": self._num_aborts}


@app.get("/stats")
def stats() -> Response:
    """Get the statistics of the engine."""
    return JSONResponse(engine.testing_stats())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
                                                      start_engine_loop=False)
    vllm.entrypoints.api_server.engine = engine
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
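For context, a minimal client-side sketch (not part of this commit; it assumes the server above is already running on localhost:8000) of how the extra /stats endpoint is meant to be consumed:

import requests

# Ask the test server how many requests it has aborted so far.
resp = requests.get("http://localhost:8000/stats")
resp.raise_for_status()
print(resp.json())  # e.g. {"num_aborted_requests": 0}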
tests/async_engine/test_api_server.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests


def _query_server(prompt: str) -> dict:
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": 100,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    response.raise_for_status()
    return response.json()


@pytest.fixture
def api_server():
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    yield
    uvicorn_process.terminate()


def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["Hello world"] * 1
        result = None
        while not result:
            try:
                for result in pool.map(_query_server, prompts):
                    break
            except:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests
        pool.map_async(_query_server, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

    # check cancellation stats
    num_aborted_requests = requests.get(
        "http://localhost:8000/stats").json()["num_aborted_requests"]
    assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result
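The test above aborts requests in bulk by terminating the client pool mid-flight. A single-request sketch of the same idea (hypothetical, not from the commit; the endpoint and payload mirror _query_server, and the 0.01-second timeout is an illustrative value): when the client drops the connection, the server cancels the handler coroutine, which in turn aborts the request in the engine and should show up in num_aborted_requests.

import requests

try:
    requests.post("http://localhost:8000/generate",
                  json={
                      "prompt": "Hello world",
                      "max_tokens": 100,
                      "temperature": 0,
                      "ignore_eos": True
                  },
                  timeout=0.01)  # give up almost immediately
except requests.exceptions.RequestException:
    pass  # the dropped request should now count toward num_aborted_requests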
tests/async_engine/test_request_tracker.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import pytest

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


def test_request_tracker():
    tracker = RequestTracker()
    stream_1 = tracker.add_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    assert not new
    assert stream_1.finished

    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request("5")
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished
vllm/core/scheduler.py
@@ -92,7 +92,10 @@ class Scheduler:
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            for seq_group in state_queue:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
                 if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
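The new comment describes a classic pitfall, illustrated by this standalone snippet (not from the commit): removing items while iterating forward shifts the remaining elements left, so the iterator skips every element that follows a removal, while iterating in reverse keeps the unvisited indices stable.

items = ["a", "b", "c", "d"]
for x in items:  # forward iteration while removing
    items.remove(x)
print(items)  # ['b', 'd'] -- 'b' and 'd' were skipped, not removed

items = ["a", "b", "c", "d"]
for x in reversed(items):  # reverse iteration is safe here
    items.remove(x)
print(items)  # []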
vllm/engine/async_llm_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -14,6 +14,28 @@ from vllm.sampling_params import SamplingParams
 logger = init_logger(__name__)


+class AsyncEngineDeadError(RuntimeError):
+    pass
+
+
+def _raise_exception_on_finish(task: asyncio.Task,
+                               request_tracker: "RequestTracker") -> None:
+    msg = ("Task finished unexpectedly. This should never happen! "
+           "Please open an issue on Github.")
+    try:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            return
+        except Exception as exc:
+            raise AsyncEngineDeadError(
+                msg + " See stack trace above for the actual cause.") from exc
+        raise AsyncEngineDeadError(msg)
+    except Exception as exc:
+        request_tracker.propagate_exception(exc)
+        raise exc
+
+
 class AsyncStream:
     """A stream of RequestOutputs for a request that can be
     iterated over asynchronously."""
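A minimal standalone sketch (simplified names, not the vLLM code) of the mechanism this helper implements: a done callback on the background task turns any crash into a loud, typed error instead of a silently dead loop; the real version additionally pushes the exception into every open stream via propagate_exception().

import asyncio


class EngineDead(RuntimeError):
    pass


async def crashing_loop() -> None:
    raise ValueError("engine step failed")


def on_finish(task: asyncio.Task) -> None:
    # Mirrors _raise_exception_on_finish, minus the request tracker.
    try:
        task.result()
    except asyncio.CancelledError:
        return  # normal shutdown
    except Exception as exc:
        raise EngineDead("Task finished unexpectedly.") from exc


async def main() -> None:
    task = asyncio.get_event_loop().create_task(crashing_loop())
    task.add_done_callback(on_finish)
    await asyncio.sleep(0.1)
    # asyncio reports EngineDead via the event loop's exception handler.


asyncio.run(main())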
@@ -43,15 +65,90 @@ class AsyncStream:
         result = await self._queue.get()
         if result is StopIteration:
             raise StopAsyncIteration
+        elif isinstance(result, Exception):
+            raise result
         return result


-def _raise_exception_on_finish(task: asyncio.Task) -> None:
-    try:
-        task.result()
-    except Exception as e:
-        raise RuntimeError("Task finished unexpectedly.") from e
-    raise RuntimeError("Task finished unexpectedly.")
+class RequestTracker:
+    """Synchronous abstraction for tracking requests."""
+
+    def __init__(self) -> None:
+        self._request_streams: Dict[str, AsyncStream] = {}
+        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
+                                                dict]] = asyncio.Queue()
+
+    def __contains__(self, item):
+        return item in self._request_streams
+
+    def propagate_exception(self, exc: Exception) -> None:
+        """Propagate an exception to all request streams."""
+        for stream in self._request_streams.values():
+            stream.put(exc)
+
+    def process_request_output(self,
+                               request_output: RequestOutput,
+                               *,
+                               verbose: bool = False) -> None:
+        """Process a request output from the engine."""
+        request_id = request_output.request_id
+
+        self._request_streams[request_id].put(request_output)
+        if request_output.finished:
+            if verbose:
+                logger.info(f"Finished request {request_id}.")
+            self.abort_request(request_id)
+
+    def add_request(self, request_id: str,
+                    **engine_add_request_kwargs) -> AsyncStream:
+        """Add a request to be sent to the engine on the next background
+        loop iteration."""
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {
+            "request_id": request_id,
+            **engine_add_request_kwargs
+        }))
+        return stream
+
+    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+        """Abort a request during next background loop iteration."""
+        if verbose:
+            logger.info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[
+                request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[dict] = []
+        finished_requests: Set[str] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        return new_requests, finished_requests
+
+
 class _AsyncLLMEngine(LLMEngine):
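The tracker keeps callers fully decoupled from the background loop: add_request() and abort_request() only enqueue, and the loop drains both queues once per iteration via get_new_and_finished_requests(). On the caller side, the returned AsyncStream is consumed like this (a sketch with assumed setup; stream is the value returned by add_request()):

async def consume(stream) -> None:
    # Each item is a RequestOutput; the loop ends when the stream is
    # finished (StopIteration sentinel), and an exception pushed via
    # propagate_exception() re-raises here, so engine crashes reach callers.
    async for request_output in stream:
        print(request_output.request_id, request_output.finished)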
@@ -150,16 +247,15 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)

-        # Request id -> stream.
-        self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()

     @property
     def is_running(self) -> bool:
-        return self.background_loop is not None
+        return (self.background_loop is not None
+                and not self.background_loop.done())

     def start_background_loop(self) -> None:
         """Start the background loop."""
@@ -167,7 +263,9 @@
             raise RuntimeError("Background loop is already running.")
         self.background_loop = asyncio.get_event_loop().create_task(
             self.run_engine_loop())
-        self.background_loop.add_done_callback(_raise_exception_on_finish)
+        self.background_loop.add_done_callback(
+            partial(_raise_exception_on_finish,
+                    request_tracker=self.request_tracker))

     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -181,6 +279,21 @@

     async def engine_step(self):
         """Kick the engine to process the waiting requests."""
+
+        new_requests, finished_requests = (
+            self.request_tracker.get_new_and_finished_requests())
+
+        for new_request in new_requests:
+            # Add the request into the vLLM engine's waiting queue.
+            # TODO: Maybe add add_request_batch to reduce Ray overhead
+            if self.engine_use_ray:
+                await self.engine.add_request.remote(**new_request)
+            else:
+                self.engine.add_request(**new_request)
+
+        if finished_requests:
+            await self._engine_abort(finished_requests)
+
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
@@ -188,20 +301,8 @@

         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            request_id = request_output.request_id
-            self.request_streams[request_id].put(request_output)
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-                self.request_streams[request_id].finish()
-                self.finished_requests.put_nowait(request_id)
-
-        finished_request = set()
-        while not self.finished_requests.empty():
-            finished_request.add(self.finished_requests.get_nowait())
-        await self._engine_abort(finished_request)
-        for request_id in finished_request:
-            del self.request_streams[request_id]
+            self.request_tracker.process_request_output(
+                request_output, verbose=self.log_requests)

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
@@ -228,25 +329,19 @@
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")

-        if request_id in self.request_streams:
-            raise KeyError(f"Request {request_id} already exists.")
-        stream = AsyncStream(request_id)
-        self.request_streams[request_id] = stream
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")

-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
+        stream = self.request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            sampling_params=sampling_params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time)

         return stream
@@ -300,6 +395,13 @@
         Args:
             request_id: The unique id of the request.
         """
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")
+
         return self._abort(request_id)

     def _abort(self, request_id: str) -> None:
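With the is_running guard (which now also checks background_loop.done()), both generate() and abort() fail fast when the loop has died instead of waiting on a stream that will never be fed. A hypothetical caller-side sketch:

from vllm.engine.async_llm_engine import AsyncEngineDeadError


async def safe_abort(engine, request_id: str) -> bool:
    """Abort a request, reporting False if the engine loop has died."""
    try:
        await engine.abort(request_id)
        return True
    except AsyncEngineDeadError:
        # Background loop crashed; surface an error / restart the process.
        return False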
@@ -311,16 +413,8 @@
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_streams or self.request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        if self.log_requests:
-            logger.info(f"Aborted request {request_id}.")
-
-        self.request_streams[request_id].finish()
-        self.finished_requests.put_nowait(request_id)
+        self.request_tracker.abort_request(request_id,
+                                           verbose=self.log_requests)

     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
vllm/entrypoints/api_server.py
@@ -14,6 +14,7 @@ from vllm.utils import random_uuid
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None


 @app.post("/generate")
vllm/entrypoints/openai/api_server.py
@@ -44,6 +44,7 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds
 logger = init_logger(__name__)
 served_model = None
 app = fastapi.FastAPI()
+engine = None


 def create_error_response(status_code: HTTPStatus,