Remove AsyncLLMEngine busy loop, shield background task (#1059)

Antoni Baum 2023-09-17 00:29:08 -07:00 committed by GitHub
parent e3e79e9e8a
commit ff36139ffc
4 changed files with 154 additions and 18 deletions

requirements-dev.txt

@@ -11,3 +11,4 @@ types-setuptools
# testing
pytest
pytest-forked
pytest-asyncio
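
pytest-asyncio is added so the new async-engine test below can be written as a coroutine. A minimal sketch of the pattern it enables (hypothetical test, not part of this commit):

import asyncio

import pytest


@pytest.mark.asyncio
async def test_event_wakeup():
    # pytest-asyncio provides the event loop and awaits the coroutine.
    event = asyncio.Event()
    asyncio.get_running_loop().call_soon(event.set)
    await event.wait()
    assert event.is_set()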

tests/async_engine/test_async_llm_engine.py

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass

import pytest

from vllm.engine.async_llm_engine import AsyncLLMEngine


@dataclass
class RequestOutput:
    request_id: int
    finished: bool = False


class MockEngine:

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    async def step_async(self):
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    def generate(self, request_id):
        self.request_id = request_id

    def stop_generating(self):
        self.request_id = None

    def add_request(self, **kwargs):
        self.add_request_calls += 1
        return

    def abort_request(self, request_id):
        self.abort_request_calls += 1
        return


class MockAsyncLLMEngine(AsyncLLMEngine):

    def _init_engine(self, *args, **kwargs):
        return MockEngine()


@pytest.mark.asyncio
async def test_new_requests_event():
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()

    # With no requests, the background loop parks and never steps.
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    # A new request wakes the loop for exactly one step.
    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    # While the engine reports in-progress outputs, the loop keeps
    # stepping without waiting on the event.
    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    # Once no outputs are produced, the loop goes back to waiting.
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
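
The test above pins down the new loop behavior end to end: no steps while idle, a wake-up and one step per new request, continuous stepping only while the engine reports in-progress outputs, and a return to the idle wait afterwards. It can be run on its own with pytest:

pytest -q tests/async_engine/test_async_llm_engine.py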

tests/async_engine/test_request_tracker.py

@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


class DummyEvent:

    def __init__(self):
        self._flag = False

    def set(self):
        self._flag = True

    def clear(self):
        self._flag = False


def test_request_tracker():
    tracker = RequestTracker()
    tracker.new_requests_event = DummyEvent()
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event._flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event._flag
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
@@ -15,7 +30,9 @@ def test_request_tracker():
    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event._flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event._flag
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
@@ -26,6 +43,7 @@ def test_request_tracker():
    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")
    assert not tracker.new_requests_event._flag
    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
@@ -36,6 +54,7 @@ def test_request_tracker():
    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    assert tracker.new_requests_event._flag
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
@@ -43,9 +62,11 @@ def test_request_tracker():
    assert stream_4.finished
    stream_5 = tracker.add_request("5")
    assert tracker.new_requests_event._flag
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event._flag
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
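
DummyEvent keeps this tracker test synchronous: it records set()/clear() calls without needing a running event loop. The real new_requests_event is an asyncio.Event, whose contract the dummy mirrors; a minimal sketch of that contract (standard library only):

import asyncio


async def main():
    event = asyncio.Event()
    assert not event.is_set()
    event.set()  # wait() now returns immediately and waiters wake up
    await event.wait()
    event.clear()  # subsequent wait() calls block until the next set()
    assert not event.is_set()


asyncio.run(main())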

vllm/engine/async_llm_engine.py

@@ -1,7 +1,8 @@
import asyncio
import time
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union)
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -78,14 +79,24 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._request_streams
def propagate_exception(self, exc: Exception) -> None:
"""Propagate an exception to all request streams."""
for stream in self._request_streams.values():
stream.put(exc)
def init_event(self):
self.new_requests_event = asyncio.Event()
def propagate_exception(self,
exc: Exception,
request_id: Optional[str] = None) -> None:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
else:
for stream in self._request_streams.values():
stream.put(exc)
def process_request_output(self,
request_output: RequestOutput,
@@ -112,6 +123,9 @@ class RequestTracker:
"request_id": request_id,
**engine_add_request_kwargs
}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -148,8 +162,13 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.request_tracker: RequestTracker = RequestTracker()
self.background_loop = None
        # We need to keep a reference to the unshielded task
        # as well to prevent it from being garbage collected.
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
@property
def is_running(self) -> bool:
@@ -264,11 +287,14 @@
"""Start the background loop."""
if self.is_running:
raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task(
self.run_engine_loop())
self.background_loop.add_done_callback(
self._request_tracker.init_event()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self.request_tracker))
request_tracker=self._request_tracker))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -280,11 +306,13 @@
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self):
"""Kick the engine to process the waiting requests."""
async def engine_step(self) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self.request_tracker.get_new_and_finished_requests())
self._request_tracker.get_new_and_finished_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
@@ -304,9 +332,11 @@
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
self.request_tracker.process_request_output(
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
return len(request_outputs) > 0
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
@@ -314,8 +344,12 @@
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False
while True:
await self.engine_step()
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step()
await asyncio.sleep(0)
async def add_request(
@@ -350,7 +384,7 @@
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
stream = self.request_tracker.add_request(
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
@@ -428,8 +462,8 @@
Args:
request_id: The unique id of the request.
"""
self.request_tracker.abort_request(request_id,
verbose=self.log_requests)
self._request_tracker.abort_request(request_id,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""