Remove AsyncLLMEngine busy loop, shield background task (#1059)

Antoni Baum 2023-09-17 00:29:08 -07:00 committed by GitHub
parent e3e79e9e8a
commit ff36139ffc
4 changed files with 154 additions and 18 deletions
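
In short: the background engine loop previously ran engine_step in a tight `while True` loop even when the engine had nothing to do. RequestTracker now owns an asyncio.Event that add_request sets and get_new_and_finished_requests clears, and run_engine_loop waits on that event whenever no requests are in progress. Separately, the background task is now wrapped in asyncio.shield, so cancellation of a coroutine awaiting it cannot tear down the engine loop; a reference to the unshielded task is kept alongside the shield.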

requirements-dev.txt

@@ -11,3 +11,4 @@ types-setuptools
 # testing
 pytest
 pytest-forked
+pytest-asyncio

tests/async_engine/test_async_llm_engine.py (new file)

@@ -0,0 +1,80 @@
+import asyncio
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+
+@dataclass
+class RequestOutput:
+    request_id: int
+    finished: bool = False
+
+
+class MockEngine:
+
+    def __init__(self):
+        self.step_calls = 0
+        self.add_request_calls = 0
+        self.abort_request_calls = 0
+        self.request_id = None
+
+    async def step_async(self):
+        self.step_calls += 1
+        return [RequestOutput(
+            request_id=self.request_id)] if self.request_id else []
+
+    def generate(self, request_id):
+        self.request_id = request_id
+
+    def stop_generating(self):
+        self.request_id = None
+
+    def add_request(self, **kwargs):
+        self.add_request_calls += 1
+        return
+
+    def abort_request(self, request_id):
+        self.abort_request_calls += 1
+        return
+
+
+class MockAsyncLLMEngine(AsyncLLMEngine):
+
+    def _init_engine(self, *args, **kwargs):
+        return MockEngine()
+
+
+@pytest.mark.asyncio
+async def test_new_requests_event():
+    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine.start_background_loop()
+    await asyncio.sleep(0.01)
+    assert engine.engine.step_calls == 0
+
+    await engine.add_request("1", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 1
+    assert engine.engine.step_calls == 1
+
+    await engine.add_request("2", "", None)
+    engine.engine.generate("2")
+    await asyncio.sleep(0)
+    assert engine.engine.add_request_calls == 2
+    assert engine.engine.step_calls == 2
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 3
+    engine.engine.stop_generating()
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+
+    await engine.add_request("3", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
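
The step-count arithmetic in this test relies on cooperative scheduling: each await asyncio.sleep(0) in the test body yields to the event loop exactly once, giving the background engine loop a single turn. A standalone sketch of that behavior (illustrative only, not part of the commit; all names here are made up):

import asyncio

async def main():
    steps = 0

    async def background():
        nonlocal steps
        while True:
            steps += 1
            await asyncio.sleep(0)  # yield back to the caller

    task = asyncio.create_task(background())
    await asyncio.sleep(0)  # background task gets one turn
    assert steps == 1
    await asyncio.sleep(0)  # ...and exactly one more
    assert steps == 2
    task.cancel()

asyncio.run(main())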

tests/async_engine/test_request_tracker.py

@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
 from vllm.outputs import RequestOutput
 
 
+class DummyEvent:
+
+    def __init__(self):
+        self._flag = False
+
+    def set(self):
+        self._flag = True
+
+    def clear(self):
+        self._flag = False
+
+
 def test_request_tracker():
     tracker = RequestTracker()
+    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -15,7 +30,9 @@ def test_request_tracker():
 
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -26,6 +43,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
+    assert not tracker.new_requests_event._flag
 
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -36,6 +54,7 @@ def test_request_tracker():
 
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -43,9 +62,11 @@ def test_request_tracker():
     assert stream_4.finished
 
     stream_5 = tracker.add_request("5")
+    assert tracker.new_requests_event._flag
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1
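
Note that test_request_tracker is a synchronous test, so it swaps in DummyEvent for the tracker's asyncio.Event: the dummy records set/clear transitions in a plain flag, letting the test assert exactly when the tracker signals new work without needing a running event loop.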

vllm/engine/async_llm_engine.py

@@ -1,7 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union)
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -78,14 +79,24 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
+        self.new_requests_event = None
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def propagate_exception(self, exc: Exception) -> None:
-        """Propagate an exception to all request streams."""
-        for stream in self._request_streams.values():
-            stream.put(exc)
+    def init_event(self):
+        self.new_requests_event = asyncio.Event()
+
+    def propagate_exception(self,
+                            exc: Exception,
+                            request_id: Optional[str] = None) -> None:
+        """Propagate an exception to request streams
+        (all if request_id is None)."""
+        if request_id is not None:
+            self._request_streams[request_id].put(exc)
+        else:
+            for stream in self._request_streams.values():
+                stream.put(exc)
 
     def process_request_output(self,
                                request_output: RequestOutput,
@@ -112,6 +123,9 @@ class RequestTracker:
                 "request_id": request_id,
                 **engine_add_request_kwargs
             }))
+
+        self.new_requests_event.set()
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -148,8 +162,13 @@
                 self._request_streams[stream.request_id] = stream
                 new_requests.append(new_request)
 
+        self.new_requests_event.clear()
+
         return new_requests, finished_requests
 
+    async def wait_for_new_requests(self):
+        await self.new_requests_event.wait()
+
 
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""
@@ -251,9 +270,13 @@
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
-        self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
+        # We need to keep a reference to unshielded
+        # task as well to prevent it from being garbage
+        # collected
+        self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
+        self._request_tracker = RequestTracker()
 
     @property
     def is_running(self) -> bool:
@ -264,11 +287,14 @@ class AsyncLLMEngine:
"""Start the background loop.""" """Start the background loop."""
if self.is_running: if self.is_running:
raise RuntimeError("Background loop is already running.") raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task( self._request_tracker.init_event()
self.run_engine_loop())
self.background_loop.add_done_callback( self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, partial(_raise_exception_on_finish,
request_tracker=self.request_tracker)) request_tracker=self._request_tracker))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
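
The shield matters when a coroutine awaiting background_loop is itself cancelled (for example, when a client disconnects): without it, that cancellation would propagate into the engine loop and kill it for every other request. A standalone sketch of the cancellation behavior asyncio.shield provides (illustrative only; the names are made up, and per the comment in the diff the unshielded reference is also kept so the task is not garbage collected):

import asyncio

async def main():
    async def engine_loop():
        while True:
            await asyncio.sleep(0.01)

    inner = asyncio.create_task(engine_loop())
    shielded = asyncio.shield(inner)

    shielded.cancel()              # cancels only the shield wrapper
    await asyncio.sleep(0.05)
    assert not inner.cancelled()   # the inner task keeps running
    inner.cancel()                 # direct cancellation still works

asyncio.run(main())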
@@ -280,11 +306,13 @@
         engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
         return engine_class(*args, **kwargs)
 
-    async def engine_step(self):
-        """Kick the engine to process the waiting requests."""
+    async def engine_step(self) -> bool:
+        """Kick the engine to process the waiting requests.
+
+        Returns True if there are in-progress requests."""
 
         new_requests, finished_requests = (
-            self.request_tracker.get_new_and_finished_requests())
+            self._request_tracker.get_new_and_finished_requests())
 
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
@@ -304,9 +332,11 @@
 
         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            self.request_tracker.process_request_output(
+            self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
 
+        return len(request_outputs) > 0
+
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
             await self.engine.abort_request.remote(request_ids)
@@ -314,8 +344,12 @@
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
+        # Initialize the RequestTracker here so it uses the right event loop.
+        has_requests_in_progress = False
         while True:
-            await self.engine_step()
+            if not has_requests_in_progress:
+                await self._request_tracker.wait_for_new_requests()
+            has_requests_in_progress = await self.engine_step()
             await asyncio.sleep(0)
 
     async def add_request(
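
This hunk is the busy-loop removal itself: with no requests in flight the loop parks in wait_for_new_requests instead of re-invoking engine_step as fast as the event loop allows, and the boolean now returned by engine_step tells it when it may go back to sleep. The remaining await asyncio.sleep(0) is a deliberate yield point so other coroutines get scheduled between engine steps while requests are in progress.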
@@ -350,7 +384,7 @@
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")
 
-        stream = self.request_tracker.add_request(
+        stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
             sampling_params=sampling_params,
@@ -428,8 +462,8 @@
         Args:
             request_id: The unique id of the request.
         """
-        self.request_tracker.abort_request(request_id,
-                                           verbose=self.log_requests)
+        self._request_tracker.abort_request(request_id,
+                                            verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""