[Bug][Frontend] Improve ZMQ client robustness (#7443)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent df1a21131d
commit cde9183b40
tests/entrypoints/openai/rpc/__init__.py  (new file, 0 lines)
tests/entrypoints/openai/rpc/test_zmq_client.py  (new file, 119 lines)
tests/entrypoints/openai/rpc/test_zmq_client.py
@@ -0,0 +1,119 @@
+import asyncio
+import tempfile
+import unittest
+import unittest.mock
+import uuid
+
+import pytest
+import pytest_asyncio
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
+                                                RPCClientClosedError)
+from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest_asyncio.fixture(scope="function")
+async def dummy_server(tmp_socket, monkeypatch):
+    dummy_engine = unittest.mock.AsyncMock()
+
+    def dummy_engine_builder(*args, **kwargs):
+        return dummy_engine
+
+    with monkeypatch.context() as m:
+        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
+        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
+
+    loop = asyncio.get_running_loop()
+    server_task = loop.create_task(server.run_server_loop())
+
+    try:
+        yield server
+    finally:
+        server_task.cancel()
+        server.cleanup()
+
+
+@pytest_asyncio.fixture(scope="function")
+async def client(tmp_socket):
+    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
+    # Sanity check: the server is connected
+    await client._wait_for_server_rpc()
+
+    try:
+        yield client
+    finally:
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
+                                                client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Make the server _not_ reply with a model config
+        m.setattr(dummy_server, "get_config", lambda x: None)
+        m.setattr(client, "_data_timeout", 10)
+
+        # And ensure the task completes anyway
+        # (client.setup() invokes server.get_config())
+        client_task = asyncio.get_running_loop().create_task(client.setup())
+        with pytest.raises(TimeoutError, match="Server didn't reply within"):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
+                                          client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Hang all abort requests
+        m.setattr(dummy_server, "abort", lambda x: None)
+        m.setattr(client, "_data_timeout", 10)
+
+        # Ensure the client doesn't hang
+        client_task = asyncio.get_running_loop().create_task(
+            client.abort("test request id"))
+        with pytest.raises(TimeoutError, match="Server didn't reply within"):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_data_methods_reraise_exceptions(
+        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
+    with monkeypatch.context() as m:
+        # Make the server raise some random exception
+        exception = RuntimeError("Client test exception")
+
+        def raiser():
+            raise exception
+
+        m.setattr(dummy_server.engine, "get_model_config", raiser)
+        m.setattr(client, "_data_timeout", 10)
+
+        client_task = asyncio.get_running_loop().create_task(client.setup())
+        # And ensure the task completes, raising the exception
+        with pytest.raises(RuntimeError, match=str(exception)):
+            await asyncio.wait_for(client_task, timeout=0.05)
+
+
+@pytest.mark.asyncio
+async def test_client_errors_after_closing(monkeypatch, dummy_server,
+                                           client: AsyncEngineRPCClient):
+
+    client.close()
+
+    # Healthchecks and generate requests will fail with explicit errors
+    with pytest.raises(RPCClientClosedError):
+        await client.check_health()
+    with pytest.raises(RPCClientClosedError):
+        async for _ in client.generate(None, None, None):
+            pass
+
+    # But no-ops like aborting will pass
+    await client.abort("test-request-id")
+    await client.do_log_stats()
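Note the timing these tests rely on: _data_timeout is patched down to 10 ms while the outer asyncio.wait_for allows 50 ms, so the client's own TimeoutError fires first and is what propagates out of the task. A minimal standalone sketch of that relationship (the coroutine names below are illustrative, not part of this commit):

    import asyncio

    async def replies_too_slowly():
        # Stand-in for an RPC call against a server that never answers: the
        # client's internal 10 ms poll timeout expires long before the outer
        # 50 ms asyncio.wait_for deadline, so this exception wins the race.
        await asyncio.sleep(0.01)
        raise TimeoutError("Server didn't reply within 10 ms")

    async def main():
        task = asyncio.get_running_loop().create_task(replies_too_slowly())
        try:
            await asyncio.wait_for(task, timeout=0.05)
        except TimeoutError as e:
            print(f"caught: {e}")

    asyncio.run(main())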
vllm/entrypoints/openai/api_server.py
@@ -6,7 +6,7 @@ import os
 import re
 import tempfile
 from argparse import Namespace
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, suppress
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set
 
@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+            with suppress(Exception):
+                await async_engine_client.do_log_stats()
 
     if not engine_args.disable_log_stats:
         task = asyncio.create_task(_force_log())
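Wrapping the periodic do_log_stats call in suppress(Exception) keeps the background logging task alive across transient RPC failures, instead of letting the first exception silently kill the loop. A small self-contained sketch of the pattern (flaky_call is a placeholder, not a vLLM API):

    import asyncio
    from contextlib import suppress

    async def flaky_call():
        raise RuntimeError("backend went away")

    async def periodic_logger():
        for _ in range(3):
            # Without suppress(), the first failure would end this task and
            # all later iterations would never run.
            with suppress(Exception):
                await flaky_call()
            await asyncio.sleep(0.01)

    asyncio.run(periodic_logger())  # finishes all iterations despite failures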
vllm/entrypoints/openai/rpc/__init__.py
@@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
 # Success string used for RPC instructions.
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
-# Timeouts.
-VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
-VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
-
 # Minimum value of ZMQ.SOCKET_LIMIT to run mp.
 VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
 
vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,5 @@
 import asyncio
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import Any, AsyncGenerator, Mapping, Optional
 from uuid import uuid4
 
@@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf: disable
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
-                                         VLLM_RPC_HEALTH_TIMEOUT_MS,
-                                         VLLM_RPC_SERVER_START_TIMEOUT_MS,
                                          VLLM_RPC_SOCKET_LIMIT_CUTOFF,
                                          VLLM_RPC_SUCCESS_STR,
                                          VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 # yapf: enable
+from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -32,6 +31,17 @@ logger = init_logger(__name__)
 INPROC_PROXY_PATH = f"inproc://{uuid4()}"
 
 
+class RPCClientClosedError(Exception):
+    """Exception class raised when the client is used post-close.
+
+    The client can be closed, which closes the ZMQ context. This normally
+    happens on server shutdown. In some cases, methods like abort and
+    do_log_stats will still be called and then try to open a socket, which
+    causes a ZMQError and creates a huge stack trace.
+    So, we throw this error such that we can suppress it.
+    """
+
+
 class AsyncEngineRPCClient:
     """
     RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
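The shutdown race this class addresses, reduced to a sketch (the Client class here is a simplified stand-in for AsyncEngineRPCClient, not code from this commit):

    from contextlib import suppress

    class RPCClientClosedError(Exception):
        """Stand-in for the exception defined above."""

    class Client:
        def __init__(self):
            self.closed = False  # models the ZMQ context's closed flag

        def _open_socket(self):
            if self.closed:
                raise RPCClientClosedError(
                    "The ZMQ client has already shut down")

        def abort(self):
            # Fire-and-forget calls that arrive after shutdown become
            # no-ops instead of dumping a ZMQError stack trace.
            with suppress(RPCClientClosedError):
                self._open_socket()

    c = Client()
    c.closed = True
    c.abort()  # returns quietly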
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
 
     def __init__(self, rpc_path: str):
         self.context = zmq.asyncio.Context()
+        self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
+        self._errored = False
 
         # Maximum number of sockets that can be opened (typically 65536).
         # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
 
         # Wait until server is ready.
         await self._wait_for_server_rpc()
-        self._errored = False
 
         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
@@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
     @contextmanager
     def to_proxy_socket(self):
         # Connect to the RPCServer via the proxy.
+
+        # Raise a sensible error if the client was already closed.
+        # This can happen if a server shutdown is triggered but some coroutines
+        # are still running requests.
+        # There should not be a race condition with this check because we don't
+        # yield to the event loop between here and opening the socket.
+        if self.context.closed:
+            raise RPCClientClosedError("The ZMQ client has already shut down")
+
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
@@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
             # Ping RPCServer with a request.
             await socket.send_multipart([cloudpickle.dumps(request)])
 
+            # Make sure the server responds
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
+
             # Await the data from the Server.
             data = cloudpickle.loads(await socket.recv())
 
+            if isinstance(data, Exception):
+                # Re-raise exceptions returned by the server
+                raise data
+
             if not isinstance(data, expected_type):
                 # LoRAConfig can be None.
                 if expected_type == LoRAConfig and data is None:
@@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
             self,
             request: RPC_REQUEST_TYPE,
             error_message: str,
-            timeout: Optional[int] = None,
             socket: Optional[zmq.asyncio.Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
         async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE,
-                              timeout=None):
+                              request: RPC_REQUEST_TYPE):
 
             await socket.send_multipart([cloudpickle.dumps(request)])
 
-            if timeout is not None and await socket.poll(timeout=timeout) == 0:
-                raise TimeoutError(f"Server didn't reply within {timeout} ms")
+            if await socket.poll(timeout=self._data_timeout) == 0:
+                raise TimeoutError("Server didn't reply within "
+                                   f"{self._data_timeout} ms")
 
             return cloudpickle.loads(await socket.recv())
 
         # Make a new socket connection.
         if socket is None:
             with self.to_proxy_socket() as socket:
-                response = await do_rpc_call(socket, request, timeout)
+                response = await do_rpc_call(socket, request)
 
         # Use existing socket connection.
         else:
-            response = await do_rpc_call(socket, request, timeout)
+            response = await do_rpc_call(socket, request)
 
         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
             if isinstance(response, Exception):
@@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server",
-            timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
+            error_message="Unable to start RPC Server")
 
     async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ class AsyncEngineRPCClient:
 
     async def abort(self, request_id: str):
         """Send an ABORT_REQUEST signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCAbortRequest(request_id),
-            error_message=f"RPCAbortRequest {request_id} failed")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCAbortRequest(request_id),
+                error_message=f"RPCAbortRequest {request_id} failed")
 
     async def do_log_stats(self):
         """Send a DO_LOG_STATS signal to the RPC Server"""
-
-        await self._send_one_way_rpc_request(
-            request=RPCUtilityRequest.DO_LOG_STATS,
-            error_message="RPCRequest DO_LOG_STATS failed.")
+        with suppress(RPCClientClosedError):
+            await self._send_one_way_rpc_request(
+                request=RPCUtilityRequest.DO_LOG_STATS,
+                error_message="RPCRequest DO_LOG_STATS failed.")
 
     @property
     def is_running(self) -> bool:
@@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_HEALTHY,
             error_message="Got Unhealthy response from RPC Server",
-            timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
             socket=socket)
 
     async def encode(self, *args,
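The recurring request pattern in this file, send, poll with _data_timeout, recv, then re-raise any pickled server-side exception, can be read in isolation. A minimal sketch assuming a plain zmq.asyncio DEALER socket (send_rpc and rpc_path are illustrative names, not part of this diff):

    import cloudpickle
    import zmq
    import zmq.asyncio

    async def send_rpc(rpc_path: str, request, timeout_ms: int = 5000):
        ctx = zmq.asyncio.Context()
        socket = ctx.socket(zmq.DEALER)
        try:
            socket.connect(rpc_path)
            await socket.send_multipart([cloudpickle.dumps(request)])
            # poll() resolves to 0 when nothing arrives within the timeout,
            # which is what turns a dead server into a TimeoutError instead
            # of an await that hangs forever.
            if await socket.poll(timeout=timeout_ms) == 0:
                raise TimeoutError(
                    f"Server didn't reply within {timeout_ms} ms")
            data = cloudpickle.loads(await socket.recv())
            if isinstance(data, Exception):
                raise data  # re-raise errors the server shipped back
            return data
        finally:
            socket.close()
            ctx.destroy()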
vllm/envs.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
@@ -374,6 +375,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
 
+    # Time in ms for the zmq client to wait for a response from the backend
+    # server for simple data operations
+    "VLLM_RPC_GET_DATA_TIMEOUT_MS":
+    lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
+
     # If set, allow running the engine as a separate ray actor,
     # which is a deprecated feature soon to be removed.
     # See https://github.com/vllm-project/vllm/issues/7045
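Because the timeout is resolved through vllm.envs, it can be raised for slow backends without code changes. A hedged usage sketch (the 30000 value is just an example, not a recommendation from this commit):

    import os

    # Hypothetical override: give slow backends 30 s for data-plane RPCs
    # instead of the 5000 ms default. Set before vllm.envs is consulted.
    os.environ["VLLM_RPC_GET_DATA_TIMEOUT_MS"] = "30000"

    from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
    print(VLLM_RPC_GET_DATA_TIMEOUT_MS)  # 30000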