[Bug][Frontend] Improve ZMQ client robustness (#7443)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-08-21 20:18:11 -06:00 committed by GitHub
parent df1a21131d
commit cde9183b40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 176 additions and 28 deletions

View File

View File

@ -0,0 +1,119 @@
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid
import pytest
import pytest_asyncio
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` address located in a throwaway directory.

    The address is built from a fresh temporary directory and a random UUID.
    The directory is removed when the requesting test finishes.
    """
    with tempfile.TemporaryDirectory() as socket_dir:
        socket_name = uuid.uuid4()
        yield f"ipc://{socket_dir}/{socket_name}"
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    """Yield an ``AsyncEngineRPCServer`` whose serve loop is already running.

    While the server is constructed, ``AsyncLLMEngine.from_engine_args`` is
    patched to return an ``AsyncMock``, so no real engine is built. The serve
    loop is scheduled on the running event loop; on teardown the task is
    cancelled and ``server.cleanup()`` is called.
    """
    mock_engine = unittest.mock.AsyncMock()

    with monkeypatch.context() as patcher:
        # Whatever arguments the builder receives, hand back the mock.
        patcher.setattr(AsyncLLMEngine, "from_engine_args",
                        lambda *args, **kwargs: mock_engine)
        rpc_server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)

    serve_task = asyncio.get_running_loop().create_task(
        rpc_server.run_server_loop())

    try:
        yield rpc_server
    finally:
        serve_task.cancel()
        rpc_server.cleanup()
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
    """Yield an ``AsyncEngineRPCClient`` pointed at ``tmp_socket``.

    Before the client is handed to the test, it waits for the server to
    answer, so a test never starts against an unreachable server. The client
    is closed on teardown.
    """
    rpc_client = AsyncEngineRPCClient(rpc_path=tmp_socket)

    # Sanity check: the server is connected
    await rpc_client._wait_for_server_rpc()

    try:
        yield rpc_client
    finally:
        rpc_client.close()
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
                                                client: AsyncEngineRPCClient):
    """``client.setup()`` must time out when no config reply arrives.

    The client's data timeout is lowered to 10 ms and the server's
    ``get_config`` is replaced by a stub that returns ``None``. The ``match``
    pattern requires the client's own "Server didn't reply within" message,
    so hitting the outer 0.05 s ``wait_for`` guard would fail the test.
    """
    with monkeypatch.context() as patcher:
        patcher.setattr(client, "_data_timeout", 10)
        # Make the server _not_ reply with a model config
        patcher.setattr(dummy_server, "get_config", lambda x: None)

        # client.setup() invokes server.get_config(); the task must still
        # complete rather than hang.
        setup_task = asyncio.create_task(client.setup())
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(setup_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
                                          client: AsyncEngineRPCClient):
    """``client.abort()`` must time out instead of hanging.

    The server's ``abort`` is replaced by a stub that returns ``None`` and the
    client's data timeout is lowered to 10 ms. The client is expected to
    raise its own "Server didn't reply within" ``TimeoutError`` before the
    outer 0.05 s ``wait_for`` guard expires.
    """
    with monkeypatch.context() as patcher:
        patcher.setattr(client, "_data_timeout", 10)
        # Hang all abort requests
        patcher.setattr(dummy_server, "abort", lambda x: None)

        # Ensure the client doesn't hang
        abort_task = asyncio.create_task(client.abort("test request id"))
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(abort_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
    """An exception raised by the engine must come out of ``client.setup()``.

    The engine's ``get_model_config`` is replaced by a function that raises a
    known ``RuntimeError``. The client is expected to re-raise an error with
    the same type and message.
    """
    # Some random exception for the server side to raise
    injected_error = RuntimeError("Client test exception")

    def raise_injected():
        raise injected_error

    with monkeypatch.context() as patcher:
        patcher.setattr(dummy_server.engine, "get_model_config",
                        raise_injected)
        patcher.setattr(client, "_data_timeout", 10)

        setup_task = asyncio.create_task(client.setup())
        # And ensure the task completes, raising the exception
        with pytest.raises(RuntimeError, match=str(injected_error)):
            await asyncio.wait_for(setup_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
                                           client: AsyncEngineRPCClient):
    """Check how each client call behaves once the client has been closed.

    Health checks and generate requests must raise ``RPCClientClosedError``.
    ``abort`` and ``do_log_stats`` must return without raising.
    """
    client.close()

    # Healthchecks and generate requests will fail with explicit errors
    with pytest.raises(RPCClientClosedError):
        await client.check_health()

    with pytest.raises(RPCClientClosedError):
        outputs = client.generate(None, None, None)
        async for _ in outputs:
            pass

    # But no-ops like aborting will pass
    await client.abort("test-request-id")
    await client.do_log_stats()

View File

@ -6,7 +6,7 @@ import os
import re import re
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from contextlib import asynccontextmanager from contextlib import asynccontextmanager, suppress
from http import HTTPStatus from http import HTTPStatus
from typing import AsyncIterator, Optional, Set from typing import AsyncIterator, Optional, Set
@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
async def _force_log(): async def _force_log():
while True: while True:
await asyncio.sleep(10) await asyncio.sleep(10)
await async_engine_client.do_log_stats() with suppress(Exception):
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats: if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log()) task = asyncio.create_task(_force_log())

View File

@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions. # Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp. # Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000

View File

@ -1,5 +1,5 @@
import asyncio import asyncio
from contextlib import contextmanager from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4 from uuid import uuid4
@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTH_TIMEOUT_MS,
VLLM_RPC_SERVER_START_TIMEOUT_MS,
VLLM_RPC_SOCKET_LIMIT_CUTOFF, VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR, VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest, VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest) RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -32,6 +31,17 @@ logger = init_logger(__name__)
INPROC_PROXY_PATH = f"inproc://{uuid4()}" INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
    """Raised when the RPC client is used after it has been closed.

    Closing the client closes its ZMQ context, which normally happens on
    server shutdown. Calls such as abort and do_log_stats may still arrive
    afterwards and would try to open a socket on the closed context. That
    surfaces as a ZMQError with a very large stack trace.

    This dedicated exception type is raised instead so that callers can
    suppress that specific situation.
    """
class AsyncEngineRPCClient: class AsyncEngineRPCClient:
""" """
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
def __init__(self, rpc_path: str): def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context() self.context = zmq.asyncio.Context()
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
self._errored = False
# Maximum number of sockets that can be opened (typically 65536). # Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
# Wait until server is ready. # Wait until server is ready.
await self._wait_for_server_rpc() await self._wait_for_server_rpc()
self._errored = False
# Get the configs. # Get the configs.
self.model_config = await self._get_model_config_rpc() self.model_config = await self._get_model_config_rpc()
@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
@contextmanager @contextmanager
def to_proxy_socket(self): def to_proxy_socket(self):
# Connect to the RPCServer via the proxy. # Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")
# Note that we use DEALER to enable asynchronous communication # Note that we use DEALER to enable asynchronous communication
# to enable streaming. # to enable streaming.
socket = self.context.socket(zmq.constants.DEALER) socket = self.context.socket(zmq.constants.DEALER)
@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
# Ping RPCServer with a request. # Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)]) await socket.send_multipart([cloudpickle.dumps(request)])
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
# Await the data from the Server. # Await the data from the Server.
data = cloudpickle.loads(await socket.recv()) data = cloudpickle.loads(await socket.recv())
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
raise data
if not isinstance(data, expected_type): if not isinstance(data, expected_type):
# LoRAConfig can be None. # LoRAConfig can be None.
if expected_type == LoRAConfig and data is None: if expected_type == LoRAConfig and data is None:
@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
self, self,
request: RPC_REQUEST_TYPE, request: RPC_REQUEST_TYPE,
error_message: str, error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None): socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action.""" """Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket, async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE, request: RPC_REQUEST_TYPE):
timeout=None):
await socket.send_multipart([cloudpickle.dumps(request)]) await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0: if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError(f"Server didn't reply within {timeout} ms") raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv()) return cloudpickle.loads(await socket.recv())
# Make a new socket connection. # Make a new socket connection.
if socket is None: if socket is None:
with self.to_proxy_socket() as socket: with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout) response = await do_rpc_call(socket, request)
# Use existing socket connection. # Use existing socket connection.
else: else:
response = await do_rpc_call(socket, request, timeout) response = await do_rpc_call(socket, request)
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
if isinstance(response, Exception): if isinstance(response, Exception):
@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY, request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server", error_message="Unable to start RPC Server")
timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
async def _get_model_config_rpc(self) -> ModelConfig: async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server""" """Get the ModelConfig object from the RPC Server"""
@ -308,17 +335,17 @@ class AsyncEngineRPCClient:
async def abort(self, request_id: str): async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server""" """Send an ABORT_REQUEST signal to the RPC Server"""
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed") error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self): async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server""" """Send a DO_LOG_STATS signal to the RPC Server"""
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS, request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.") error_message="RPCRequest DO_LOG_STATS failed.")
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request( await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY, request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server", error_message="Got Unhealthy response from RPC Server",
timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
socket=socket) socket=socket)
async def encode(self, *args, async def encode(self, *args,

View File

@ -56,6 +56,7 @@ if TYPE_CHECKING:
VERBOSE: bool = False VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@ -374,6 +375,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")), ("1", "true")),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor, # If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed. # which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045 # See https://github.com/vllm-project/vllm/issues/7045