[Bugfix] Use heartbeats instead of health checks (#8583)
This commit is contained in:
parent
6da1ab6b41
commit
6e0c9d6bd0
@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
|
||||
await client.check_health()
|
||||
|
||||
# Trigger an abort on the client side.
|
||||
async def bad_abort_after_2s():
|
||||
await asyncio.sleep(2.0)
|
||||
await client.abort(request_id="foo")
|
||||
# This request ID does not exist, and will cause the engine to error
|
||||
await client.abort(request_id="foo")
|
||||
|
||||
# Trigger an abort in 2s from now.
|
||||
abort_task = asyncio.create_task(bad_abort_after_2s())
|
||||
|
||||
# Exception in abort() will happen during this generation.
|
||||
# This will kill the engine and should return ENGINE_DEAD_ERROR
|
||||
# Future generation requests will now fail
|
||||
# with reference to the original KeyError("foo")
|
||||
with pytest.raises(MQEngineDeadError) as execinfo:
|
||||
async for _ in client.generate(
|
||||
inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=2000),
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
assert "KeyError" in repr(execinfo.value)
|
||||
assert client.errored
|
||||
|
||||
await abort_task
|
||||
|
||||
# This should raise the original error.
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
await client.check_health()
|
||||
|
@ -43,10 +43,6 @@ class RPCAbortRequest:
|
||||
request_id: str
|
||||
|
||||
|
||||
class RPCHealthRequest:
|
||||
pass
|
||||
|
||||
|
||||
class RPCStartupRequest(Enum):
|
||||
IS_SERVER_READY = 1
|
||||
|
||||
@ -56,8 +52,7 @@ class RPCStartupResponse:
|
||||
tracing_enabled: bool
|
||||
|
||||
|
||||
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
|
||||
RPCStartupRequest]
|
||||
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
|
||||
|
||||
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
|
||||
|
||||
|
@ -20,9 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCHealthRequest,
|
||||
RPCProcessRequest, RPCStartupRequest,
|
||||
RPCStartupResponse)
|
||||
RPCError, RPCProcessRequest,
|
||||
RPCStartupRequest, RPCStartupResponse)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.inputs import PromptInputs
|
||||
@ -95,9 +94,9 @@ class MQLLMEngineClient:
|
||||
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# IPC path for ack of check_health requests.
|
||||
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
# IPC path for acking heartbeats.
|
||||
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
@ -124,34 +123,28 @@ class MQLLMEngineClient:
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
|
||||
async def run_check_health_loop(self, timeout: int):
|
||||
"""Background loop that continually probes the RPCServer for health.
|
||||
|
||||
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
|
||||
the MQLLMEngine server is blocking on.
|
||||
|
||||
The Server replies on the HEALTH_SOCKET (rather than on the
|
||||
OUTPUT_SOCKET such that the messages are not intermingled with
|
||||
output streaming).
|
||||
async def run_heartbeat_loop(self, timeout: int):
|
||||
"""Background loop that continually listens to the RPCServer for
|
||||
heartbeats.
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await self.health_socket.poll(timeout=timeout) == 0:
|
||||
# Wakeup every N seconds and do a health probe.
|
||||
await self._send_one_way_rpc_request(
|
||||
RPCHealthRequest(), self.input_socket)
|
||||
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
|
||||
# No heartbeat was received. Set error and exit the loop
|
||||
self._set_errored(
|
||||
TimeoutError("No heartbeat received "
|
||||
"from MQLLMEngine"))
|
||||
logger.debug("Shutting down MQLLMEngineClient check "
|
||||
"health loop due to timeout")
|
||||
break
|
||||
|
||||
# Wait for ack from the health socket.
|
||||
await self._await_ack(error_message="Health check failed.",
|
||||
socket=self.health_socket)
|
||||
else:
|
||||
# Server sent a health status message unprompted.
|
||||
# Heartbeat received- check the message
|
||||
await self._check_success(
|
||||
error_message="Health check failed.",
|
||||
socket=self.health_socket)
|
||||
error_message="Heartbeat failed.",
|
||||
socket=self.heartbeat_socket)
|
||||
|
||||
logger.debug("Health probe successful.")
|
||||
logger.debug("Heartbeat successful.")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||
@ -234,7 +227,7 @@ class MQLLMEngineClient:
|
||||
|
||||
# Start health_loop.
|
||||
self.health_loop = asyncio.create_task(
|
||||
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
|
||||
def close(self):
|
||||
"""Destroy the ZeroMQ Context."""
|
||||
|
@ -1,5 +1,7 @@
|
||||
import pickle
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
@ -15,10 +17,10 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCHealthRequest,
|
||||
RPCProcessRequest, RPCStartupRequest,
|
||||
RPCStartupResponse)
|
||||
RPCError, RPCProcessRequest,
|
||||
RPCStartupRequest, RPCStartupResponse)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -91,9 +93,9 @@ class MQLLMEngine:
|
||||
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# Send health status back to client.
|
||||
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
# Send heartbeats back to client.
|
||||
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
@ -101,6 +103,20 @@ class MQLLMEngine:
|
||||
# Error state.
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Heartbeat thread
|
||||
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
|
||||
daemon=True)
|
||||
self._heartbeat_stop_event = threading.Event()
|
||||
# The heartbeat needs to be faster than what the client will wait for
|
||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
||||
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
|
||||
|
||||
self._last_alive_time = time.time()
|
||||
# The heartbeats can tolerate a long period of the engine chugging
|
||||
# away at a generation request.
|
||||
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
|
||||
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
if self._errored_with is not None:
|
||||
@ -131,6 +147,8 @@ class MQLLMEngine:
|
||||
try:
|
||||
logger.debug("Starting Startup Loop.")
|
||||
self.run_startup_loop()
|
||||
logger.debug("Starting heartbeat thread")
|
||||
self.heartbeat_thread.start()
|
||||
logger.debug("Starting Engine Loop.")
|
||||
self.run_engine_loop()
|
||||
except Exception as e:
|
||||
@ -144,6 +162,7 @@ class MQLLMEngine:
|
||||
def cleanup(self):
|
||||
"""Cleanup zeromq state on shutdown."""
|
||||
# Closes all sockets and destroys context.
|
||||
self._heartbeat_stop_event.set()
|
||||
self.ctx.destroy(linger=0)
|
||||
del self.engine
|
||||
|
||||
@ -182,9 +201,11 @@ class MQLLMEngine:
|
||||
"""Core busy loop of the LLMEngine."""
|
||||
|
||||
while True:
|
||||
self._alive()
|
||||
if not self.engine.has_unfinished_requests():
|
||||
# Poll until there is work to do.
|
||||
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
self._alive()
|
||||
self.engine.do_log_stats()
|
||||
logger.debug("Waiting for new requests in engine loop.")
|
||||
|
||||
@ -200,7 +221,6 @@ class MQLLMEngine:
|
||||
|
||||
def engine_step(self) -> List[RequestOutput]:
|
||||
"""Engine step wrapper with error handling."""
|
||||
|
||||
try:
|
||||
return self.engine.step()
|
||||
except SystemExit:
|
||||
@ -229,10 +249,9 @@ class MQLLMEngine:
|
||||
self._handle_process_request(request)
|
||||
elif isinstance(request, RPCAbortRequest):
|
||||
self._handle_abort_request(request)
|
||||
elif isinstance(request, RPCHealthRequest):
|
||||
self._handle_health_request()
|
||||
else:
|
||||
raise ValueError("Unknown RPCRequest Type: {request}")
|
||||
raise ValueError("Unknown RPCRequest Type: "
|
||||
f"{type(request)}")
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
@ -279,13 +298,32 @@ class MQLLMEngine:
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request.request_id)
|
||||
|
||||
def _handle_health_request(self):
|
||||
def _heartbeat_loop(self):
|
||||
while not self._heartbeat_stop_event.wait(
|
||||
timeout=self.heartbeat_interval_seconds):
|
||||
# Loops until the stop event is set
|
||||
self._heartbeat()
|
||||
|
||||
logger.debug("Exiting MQLLMEngine heartbeat thread")
|
||||
|
||||
def _heartbeat(self):
|
||||
# Send unhealthy if engine has already errored
|
||||
if self._errored_with is not None:
|
||||
self._send_unhealthy(self._errored_with)
|
||||
|
||||
# Raises error if unhealthy.
|
||||
self.engine.check_health()
|
||||
self._send_healthy()
|
||||
# Check for life of the main loop
|
||||
elif time.time() - self._last_alive_time > self.last_alive_threshold:
|
||||
self._send_unhealthy(RuntimeError("Engine loop has died"))
|
||||
|
||||
else:
|
||||
# Otherwise- check health of the engine
|
||||
# self.engine.check_health() raises on unhealthy
|
||||
try:
|
||||
self.engine.check_health()
|
||||
self._send_healthy()
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
self._send_unhealthy(e)
|
||||
|
||||
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
|
||||
"""Send List of RequestOutput to RPCClient."""
|
||||
@ -295,12 +333,14 @@ class MQLLMEngine:
|
||||
|
||||
def _send_healthy(self):
|
||||
"""Send HEALTHY message to RPCClient."""
|
||||
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
|
||||
if not self.heartbeat_socket.closed:
|
||||
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
|
||||
|
||||
def _send_unhealthy(self, error: BaseException):
|
||||
"""Send UNHEALTHY message to RPCClient."""
|
||||
error_bytes = pickle.dumps(error)
|
||||
self.health_socket.send_multipart((error_bytes, ), copy=False)
|
||||
if not self.heartbeat_socket.closed:
|
||||
error_bytes = pickle.dumps(error)
|
||||
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
|
||||
|
||||
def _async_socket_engine_callback(self,
|
||||
request_outputs: REQUEST_OUTPUTS_T):
|
||||
@ -313,6 +353,9 @@ class MQLLMEngine:
|
||||
if self._errored_with is None:
|
||||
self._errored_with = e
|
||||
|
||||
def _alive(self):
|
||||
self._last_alive_time = time.time()
|
||||
|
||||
|
||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
||||
ipc_path: str):
|
||||
|
Loading…
x
Reference in New Issue
Block a user