Remove AsyncLLMEngine busy loop, shield background task (#1059)
parent e3e79e9e8a
commit ff36139ffc
requirements-dev.txt
@@ -11,3 +11,4 @@ types-setuptools
 # testing
 pytest
 pytest-forked
+pytest-asyncio
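pytest-asyncio is what lets the new test file below declare async test functions. A minimal sketch of the mechanism (hypothetical test, not part of this commit): each coroutine marked with @pytest.mark.asyncio is run on an event loop, so the test body can await asyncio primitives directly.

# Hypothetical example (not from this commit): the marker lets pytest
# await the coroutine, so asyncio primitives work inside the test.
import asyncio

import pytest


@pytest.mark.asyncio
async def test_event_wakeup():
    event = asyncio.Event()
    asyncio.get_running_loop().call_soon(event.set)
    await event.wait()   # returns once the scheduled callback fires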
tests/async_engine/test_async_llm_engine.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import asyncio
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+
+@dataclass
+class RequestOutput:
+    request_id: int
+    finished: bool = False
+
+
+class MockEngine:
+
+    def __init__(self):
+        self.step_calls = 0
+        self.add_request_calls = 0
+        self.abort_request_calls = 0
+        self.request_id = None
+
+    async def step_async(self):
+        self.step_calls += 1
+        return [RequestOutput(
+            request_id=self.request_id)] if self.request_id else []
+
+    def generate(self, request_id):
+        self.request_id = request_id
+
+    def stop_generating(self):
+        self.request_id = None
+
+    def add_request(self, **kwargs):
+        self.add_request_calls += 1
+        return
+
+    def abort_request(self, request_id):
+        self.abort_request_calls += 1
+        return
+
+
+class MockAsyncLLMEngine(AsyncLLMEngine):
+
+    def _init_engine(self, *args, **kwargs):
+        return MockEngine()
+
+
+@pytest.mark.asyncio
+async def test_new_requests_event():
+    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine.start_background_loop()
+    await asyncio.sleep(0.01)
+    assert engine.engine.step_calls == 0
+
+    await engine.add_request("1", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 1
+    assert engine.engine.step_calls == 1
+
+    await engine.add_request("2", "", None)
+    engine.engine.generate("2")
+    await asyncio.sleep(0)
+    assert engine.engine.add_request_calls == 2
+    assert engine.engine.step_calls == 2
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 3
+    engine.engine.stop_generating()
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+    await asyncio.sleep(0)
+    assert engine.engine.step_calls == 4
+
+    await engine.add_request("3", "", None)
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
+    await asyncio.sleep(0.01)
+    assert engine.engine.add_request_calls == 3
+    assert engine.engine.step_calls == 5
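A note on the timing in this test: await asyncio.sleep(0) yields exactly one scheduling turn to the background loop, which is what makes the step_calls counts deterministic. A standalone sketch of that property (not part of the commit, assuming one other ready task):

# Standalone sketch: sleep(0) hands the loop one turn, so a ready
# background task advances exactly once per yield.
import asyncio


async def main():
    ticks = 0

    async def background():
        nonlocal ticks
        while True:
            ticks += 1
            await asyncio.sleep(0)

    task = asyncio.create_task(background())
    await asyncio.sleep(0)   # one turn: background increments once
    assert ticks == 1
    task.cancel()


asyncio.run(main())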
tests/async_engine/test_request_tracker.py
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
 from vllm.outputs import RequestOutput
 
 
+class DummyEvent:
+
+    def __init__(self):
+        self._flag = False
+
+    def set(self):
+        self._flag = True
+
+    def clear(self):
+        self._flag = False
+
+
 def test_request_tracker():
     tracker = RequestTracker()
+    tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished
@@ -15,7 +30,9 @@ def test_request_tracker():
 
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"
@@ -26,6 +43,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
+    assert not tracker.new_requests_event._flag
 
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()
@@ -36,6 +54,7 @@ def test_request_tracker():
 
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
+    assert tracker.new_requests_event._flag
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished
@@ -43,9 +62,11 @@ def test_request_tracker():
     assert stream_4.finished
 
     stream_5 = tracker.add_request("5")
+    assert tracker.new_requests_event._flag
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
+    assert not tracker.new_requests_event._flag
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1
vllm/engine/async_llm_engine.py
@@ -1,7 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union)
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -78,12 +79,22 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
+        self.new_requests_event = None
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def propagate_exception(self, exc: Exception) -> None:
-        """Propagate an exception to all request streams."""
-        for stream in self._request_streams.values():
-            stream.put(exc)
+    def init_event(self):
+        self.new_requests_event = asyncio.Event()
+
+    def propagate_exception(self,
+                            exc: Exception,
+                            request_id: Optional[str] = None) -> None:
+        """Propagate an exception to request streams
+        (all if request_id is None)."""
+        if request_id is not None:
+            self._request_streams[request_id].put(exc)
+        else:
+            for stream in self._request_streams.values():
+                stream.put(exc)
 
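init_event() defers creating the asyncio.Event until the background loop starts (the run_engine_loop hunk below notes it should use "the right event loop"). On the Python versions current at the time (before 3.10), an Event created outside the running loop could bind to a different loop and make wait() fail. A minimal sketch of the deferred-creation pattern (illustrative class name, not vLLM API):

# Minimal sketch: defer Event creation until a loop is running, so
# waiters and setters share the same loop (mirrors init_event above).
import asyncio


class Tracker:

    def __init__(self):
        self.new_requests_event = None   # do not create the Event here

    def init_event(self):
        # Called from inside the running loop.
        self.new_requests_event = asyncio.Event()


async def main():
    tracker = Tracker()
    tracker.init_event()
    tracker.new_requests_event.set()
    await tracker.new_requests_event.wait()   # returns immediately


asyncio.run(main())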
@@ -112,6 +123,9 @@ class RequestTracker:
                 "request_id": request_id,
                 **engine_add_request_kwargs
             }))
+
+        self.new_requests_event.set()
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -148,8 +162,13 @@ class RequestTracker:
                 self._request_streams[stream.request_id] = stream
                 new_requests.append(new_request)
 
+        self.new_requests_event.clear()
+
         return new_requests, finished_requests
 
+    async def wait_for_new_requests(self):
+        await self.new_requests_event.wait()
+
 
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""
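Taken together, the tracker's protocol is: set() on every add_request, clear() once get_new_and_finished_requests drains the queues, and wait_for_new_requests() only while idle. A condensed, runnable sketch of that protocol (illustrative names, not vLLM API):

# Condensed sketch of the wakeup protocol: set on add, clear on drain,
# wait only while idle.
import asyncio


async def main():
    event = asyncio.Event()
    queue = []
    processed = []

    def add(item):
        queue.append(item)
        event.set()                  # wake the worker

    def drain():
        items, queue[:] = list(queue), []
        event.clear()                # next wait() blocks until new work
        return items

    async def worker():
        while True:
            if not queue:
                await event.wait()   # park instead of busy-polling
            processed.extend(drain())

    task = asyncio.create_task(worker())
    add("req-1")
    await asyncio.sleep(0)
    assert processed == ["req-1"]
    task.cancel()


asyncio.run(main())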
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
-        self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
+        # We need to keep a reference to unshielded
+        # task as well to prevent it from being garbage
+        # collected
+        self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
+        self._request_tracker = RequestTracker()
 
     @property
     def is_running(self) -> bool:
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
         """Start the background loop."""
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self.background_loop = asyncio.get_event_loop().create_task(
-            self.run_engine_loop())
-        self.background_loop.add_done_callback(
+        self._request_tracker.init_event()
+
+        self._background_loop_unshielded = asyncio.get_event_loop(
+        ).create_task(self.run_engine_loop())
+        self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self.request_tracker))
+                    request_tracker=self._request_tracker))
+        self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
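The shield is the second half of the PR title: a caller awaiting self.background_loop that later gets cancelled must not take the engine loop down with it, and the raw task reference is kept so the loop is not garbage collected. A standalone sketch of asyncio.shield semantics (not part of the commit):

# Standalone sketch: cancelling the shielded wrapper does not cancel
# the inner task, which keeps running.
import asyncio


async def main():

    async def engine_loop():
        while True:
            await asyncio.sleep(0.01)

    inner = asyncio.create_task(engine_loop())   # keep this reference
    outer = asyncio.shield(inner)

    outer.cancel()                   # caller-side cancellation
    await asyncio.sleep(0)
    assert not inner.cancelled()     # the loop survived
    inner.cancel()                   # only this stops it for real


asyncio.run(main())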
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
             engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
         return engine_class(*args, **kwargs)
 
-    async def engine_step(self):
-        """Kick the engine to process the waiting requests."""
+    async def engine_step(self) -> bool:
+        """Kick the engine to process the waiting requests.
+
+        Returns True if there are in-progress requests."""
+
         new_requests, finished_requests = (
-            self.request_tracker.get_new_and_finished_requests())
+            self._request_tracker.get_new_and_finished_requests())
 
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
 
         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            self.request_tracker.process_request_output(
+            self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
 
+        return len(request_outputs) > 0
+
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
             await self.engine.abort_request.remote(request_ids)
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
+        # Initialize the RequestTracker here so it uses the right event loop.
+        has_requests_in_progress = False
         while True:
-            await self.engine_step()
+            if not has_requests_in_progress:
+                await self._request_tracker.wait_for_new_requests()
+            has_requests_in_progress = await self.engine_step()
             await asyncio.sleep(0)
 
     async def add_request(
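This hunk is the core of the fix: the old loop called engine_step() unconditionally, spinning at full speed even with zero requests; the new loop parks on the tracker's event whenever engine_step() reports no in-progress requests. A runnable sketch contrasting the two loop shapes (counters stand in for engine_step calls; not vLLM code):

# Runnable sketch: a busy loop steps even when idle, an event-driven
# loop does no work until woken.
import asyncio


async def main():
    steps = {"busy": 0, "event": 0}
    event = asyncio.Event()

    async def busy_loop():           # pre-#1059: steps even when idle
        while True:
            steps["busy"] += 1
            await asyncio.sleep(0)

    async def event_driven_loop():   # post-#1059: parked until woken
        while True:
            await event.wait()
            steps["event"] += 1
            event.clear()

    tasks = [asyncio.create_task(busy_loop()),
             asyncio.create_task(event_driven_loop())]
    await asyncio.sleep(0.05)        # no requests arrive
    assert steps["event"] == 0       # idle loop did no work
    assert steps["busy"] > steps["event"]   # busy loop spun regardless
    for task in tasks:
        task.cancel()


asyncio.run(main())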
@@ -350,7 +384,7 @@
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")
 
-        stream = self.request_tracker.add_request(
+        stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
             sampling_params=sampling_params,
@@ -428,7 +462,7 @@
         Args:
             request_id: The unique id of the request.
         """
-        self.request_tracker.abort_request(request_id,
-                                           verbose=self.log_requests)
+        self._request_tracker.abort_request(request_id,
+                                            verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig: