[Bugfix] Use heartbeats instead of health checks (#8583)

This commit is contained in:
Joe Runde 2024-09-24 21:37:38 -06:00 committed by GitHub
parent 6da1ab6b41
commit 6e0c9d6bd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 62 deletions

View File

@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
await client.check_health()
# Trigger an abort on the client side.
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")
# This request ID does not exist, and will cause the engine to error
await client.abort(request_id="foo")
# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())
# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# Future generation requests will now fail
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
await abort_task
# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()

View File

@ -43,10 +43,6 @@ class RPCAbortRequest:
request_id: str
class RPCHealthRequest:
pass
class RPCStartupRequest(Enum):
IS_SERVER_READY = 1
@ -56,8 +52,7 @@ class RPCStartupResponse:
tracing_enabled: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

View File

@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
@ -95,9 +94,9 @@ class MQLLMEngineClient:
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@ -124,34 +123,28 @@ class MQLLMEngineClient:
finally:
socket.close(linger=0)
async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""
try:
while True:
if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
break
# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
else:
# Server sent a health status message unprompted.
# Heartbeat received- check the message
await self._check_success(
error_message="Health check failed.",
socket=self.health_socket)
error_message="Heartbeat failed.",
socket=self.heartbeat_socket)
logger.debug("Health probe successful.")
logger.debug("Heartbeat successful.")
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
@ -234,7 +227,7 @@ class MQLLMEngineClient:
# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self):
"""Destroy the ZeroMQ Context."""

View File

@ -1,5 +1,7 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCHealthRequest,
RPCProcessRequest, RPCStartupRequest,
RPCStartupResponse)
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
@ -91,9 +93,9 @@ class MQLLMEngine:
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@ -101,6 +103,20 @@ class MQLLMEngine:
# Error state.
self._errored_with: Optional[BaseException] = None
# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@ -131,6 +147,8 @@ class MQLLMEngine:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
@ -144,6 +162,7 @@ class MQLLMEngine:
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine
@ -182,9 +201,11 @@ class MQLLMEngine:
"""Core busy loop of the LLMEngine."""
while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
@ -200,7 +221,6 @@ class MQLLMEngine:
def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""
try:
return self.engine.step()
except SystemExit:
@ -229,10 +249,9 @@ class MQLLMEngine:
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: {request}")
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
except Exception as e:
self._set_errored(e)
@ -279,13 +298,32 @@ class MQLLMEngine:
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_health_request(self):
def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()
logger.debug("Exiting MQLLMEngine heartbeat thread")
def _heartbeat(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))
else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
@ -295,12 +333,14 @@ class MQLLMEngine:
def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
@ -313,6 +353,9 @@ class MQLLMEngine:
if self._errored_with is None:
self._errored_with = e
def _alive(self):
self._last_alive_time = time.time()
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):