[V1][Frontend] Improve Shutdown And Logs (#11737)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Robert Shaw 2025-04-16 22:48:34 -04:00 committed by GitHub
parent 3c776dcefb
commit 2b05b8ce69
16 changed files with 1031 additions and 347 deletions


@@ -552,6 +552,7 @@ steps:
   # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown

 - label: Plugin Tests (2 GPUs) # 40min
   working_dir: "/vllm-workspace/tests"


@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that LLM and AsyncLLM free GPU memory upon deletion."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
send_one_request: bool) -> None:
"""Test that AsyncLLM frees GPU memory upon deletion.
AsyncLLM always uses an MP client.
Args:
model: model under test
tensor_parallel_size: degree of tensor parallelism
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Instantiate AsyncLLM; make request to complete any deferred
# initialization; then delete instance
async_llm = AsyncLLM.from_engine_args(engine_args)
if send_one_request:
async for _ in async_llm.generate(
"Hello my name is",
request_id="abc",
sampling_params=SamplingParams(
max_tokens=1, output_kind=RequestOutputKind.DELTA)):
pass
del async_llm
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
enable_multiprocessing: bool,
send_one_request: bool) -> None:
"""Test that LLM frees GPU memory upon deletion.
TODO(andy) - LLM without multiprocessing.
Args:
model: model under test
tensor_parallel_size: degree of tensor parallelism
enable_multiprocessing: enable workers in separate process(es)
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Instantiate LLM; make request to complete any deferred
# initialization; then delete instance
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
if send_one_request:
llm.generate("Hello my name is",
sampling_params=SamplingParams(max_tokens=1))
del llm
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)


@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""
import asyncio
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_forward(self, *args, **kwargs):
"""Evil forward method that raise an exception after 10 calls."""
NUMBER_OF_GOOD_PASSES = 10
if not hasattr(self, "num_calls"):
self.num_calls = 0
if (self.num_calls == NUMBER_OF_GOOD_PASSES
and get_tensor_model_parallel_rank() == 0):
raise Exception("Simulated illegal memory access on Rank 0!")
self.num_calls += 1
return self.model(*args, **kwargs)
@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
model: str) -> None:
"""Test that AsyncLLM propagates a forward pass error and frees memory.
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
async_llm = AsyncLLM.from_engine_args(engine_args)
async def generate(request_id: str):
generator = async_llm.generate("Hello my name is",
request_id=request_id,
sampling_params=SamplingParams())
try:
async for _ in generator:
pass
except Exception as e:
return e
NUM_REQS = 3
tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
outputs = await asyncio.gather(*tasks)
# Every request should get an EngineDeadError.
for output in outputs:
assert isinstance(output, EngineDeadError)
# AsyncLLM should be errored.
assert async_llm.errored
# We should not be able to make another request.
with pytest.raises(EngineDeadError):
async for _ in async_llm.generate("Hello my name is",
request_id="abc",
sampling_params=SamplingParams()):
raise Exception("We should not get here.")
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=2 * 2**30,
timeout_s=60,
)
# NOTE: when running behind the API server, shutdown is triggered there
# once an exception occurs; in this standalone test we call it explicitly.
async_llm.shutdown()
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
enable_multiprocessing: bool, model: str) -> None:
"""Test that LLM propagates a forward pass error and frees memory.
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
and >1 rank
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Monkeypatch an error in the model.
m.setattr(LlamaForCausalLM, "forward", evil_forward)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
with pytest.raises(
EngineDeadError if enable_multiprocessing else Exception):
llm.generate("Hello my name is Robert and I")
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)


@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""
import asyncio
import pytest
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_processor_error(model: str) -> None:
"""Test that AsyncLLM propagates a processor error.
Test empty tokens prompt (failure) and non-empty prompt (no failure).
AsyncLLM always uses an MP client.
"""
engine_args = AsyncEngineArgs(model=model, enforce_eager=True)
async_llm = AsyncLLM.from_engine_args(engine_args)
async def generate(request_id: str):
# [] is not allowed and will raise a ValueError in Processor.
generator = async_llm.generate(TokensPrompt(prompt_token_ids=[]),
request_id=request_id,
sampling_params=SamplingParams())
try:
async for _ in generator:
pass
except Exception as e:
return e
NUM_REQS = 3
tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
outputs = await asyncio.gather(*tasks)
# Every request should get an EngineGenerateError.
for output in outputs:
with pytest.raises(EngineGenerateError):
raise output
# AsyncLLM should not be errored.
assert not async_llm.errored
# This should be no problem.
EXPECTED_TOKENS = 5
outputs = []
async for out in async_llm.generate(
"Hello my name is",
request_id="abc",
sampling_params=SamplingParams(
max_tokens=EXPECTED_TOKENS,
output_kind=RequestOutputKind.DELTA)):
outputs.append(out)
generated_tokens = []
for out in outputs:
generated_tokens.extend(out.outputs[0].token_ids)
assert len(generated_tokens) == EXPECTED_TOKENS
async_llm.shutdown()


@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_method(self, *args, **kwargs):
"""Evil method that raises an exception."""
if get_tensor_model_parallel_rank() == 0:
raise Exception("Simulated Error in startup!")
return self.model(*args, **kwargs, intermediate_tensors=None)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(monkeypatch, model: str,
tensor_parallel_size: int,
failing_method: str) -> None:
"""Test that AsyncLLM propagates an __init__ error & frees memory.
Test profiling (forward()) and load weights failures.
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Confirm we get an exception.
with pytest.raises(Exception, match="initialization failed"):
_ = AsyncLLM.from_engine_args(engine_args)
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int,
enable_multiprocessing: bool,
failing_method: str) -> None:
"""Test that LLM propagates an __init__ error and frees memory.
Test profiling (forward()) and load weights failures.
TODO(andy) - LLM without multiprocessing.
"""
if model != "meta-llama/Llama-3.2-1B":
pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Monkeypatch an error in the model.
monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)
with pytest.raises(
Exception,
match="initialization failed"
if enable_multiprocessing else "Simulated Error in startup!"):
_ = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)


@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
"""Shutdown test utils"""
SHUTDOWN_TEST_TIMEOUT_SEC = 120
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30


@@ -7,11 +7,13 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple, Union
+from threading import Event
+from typing import Any, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
 import torch.distributed as dist
+import zmq
 from torch.distributed import ProcessGroup
 from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

@@ -400,7 +402,9 @@ class MessageQueue:
                 break

     @contextmanager
-    def acquire_read(self, timeout: Optional[float] = None):
+    def acquire_read(self,
+                     timeout: Optional[float] = None,
+                     cancel: Optional[Event] = None):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1

@@ -430,6 +434,9 @@ class MessageQueue:
                         )
                         n_warning += 1

+                    if cancel is not None and cancel.is_set():
+                        raise RuntimeError("cancelled")
+
                     # if we time out, raise an exception
                     if (timeout is not None
                             and time.monotonic() - start_time > timeout):

@@ -464,10 +471,12 @@ class MessageQueue:
             if self.n_remote_reader > 0:
                 self.remote_socket.send(serialized_obj)

-    def dequeue(self, timeout: Optional[float] = None):
+    def dequeue(self,
+                timeout: Optional[float] = None,
+                cancel: Optional[Event] = None):
         """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read(timeout) as buf:
+            with self.acquire_read(timeout, cancel) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object

@@ -475,15 +484,21 @@ class MessageQueue:
                     # see https://docs.python.org/3/library/pickle.html
                     obj = pickle.loads(buf[1:])
             if overflow:
-                recv = self.local_socket.recv()
-                obj = pickle.loads(recv)
+                obj = MessageQueue.recv(self.local_socket, timeout)
         elif self._is_remote_reader:
-            recv = self.remote_socket.recv()
-            obj = pickle.loads(recv)
+            obj = MessageQueue.recv(self.remote_socket, timeout)
         else:
             raise RuntimeError("Only readers can dequeue")
         return obj

+    @staticmethod
+    def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
+        timeout_ms = None if timeout is None else int(timeout * 1000)
+        if not socket.poll(timeout=timeout_ms):
+            raise TimeoutError
+        recv = socket.recv(copy=False)
+        return pickle.loads(recv.buffer)
+
     def broadcast_object(self, obj=None):
         if self._is_writer:
             self.enqueue(obj)
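For illustration only (not part of the diff), a minimal sketch of how a reader loop might combine the new timeout and cancel parameters of MessageQueue.dequeue(); the mq argument is assumed to be a local-reader MessageQueue, and the polling interval is arbitrary:

from threading import Event

def drain_queue(mq, stop_event: Event) -> None:
    # Hypothetical reader loop: poll in short slices so the cancel Event
    # is observed promptly during shutdown.
    while not stop_event.is_set():
        try:
            obj = mq.dequeue(timeout=1.0, cancel=stop_event)
        except TimeoutError:
            continue  # nothing arrived in this slice; poll again
        except RuntimeError:
            break     # dequeue() raises RuntimeError("cancelled") once set
        print("received:", obj)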


@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError

 logger = init_logger(__name__)

@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,
     loop = asyncio.get_running_loop()

+    watchdog_task = loop.create_task(
+        watchdog_loop(server, app.state.engine_client))
     server_task = loop.create_task(
         server.serve(sockets=[sock] if sock else None))

@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
     def signal_handler() -> None:
         # prevents the uvicorn signal handler to exit early
         server_task.cancel()
+        watchdog_task.cancel()
         if ssl_cert_refresher:
             ssl_cert_refresher.stop()

@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
                         port, process, " ".join(process.cmdline()))
         logger.info("Shutting down FastAPI HTTP server.")
         return server.shutdown()
+    finally:
+        watchdog_task.cancel()
+
+
+async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
+    """
+    Watchdog task that runs in the background, checking
+    for error state in the engine. Needed to trigger shutdown
+    if an exception arises in a StreamingResponse() generator.
+    """
+    VLLM_WATCHDOG_TIME_S = 5.0
+    while True:
+        await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
+        terminate_if_errored(server, engine)
+
+
+def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
+    """
+    See discussions here on shutting down a uvicorn server
+    https://github.com/encode/uvicorn/discussions/1103
+
+    In this case we cannot await the server shutdown here
+    because the handler must first return to close the connection
+    for this request.
+    """
+    engine_errored = engine.errored and not engine.is_running
+    if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
+        server.should_exit = True
+
+
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
-    """Adds handlers for fatal errors that should crash the server"""
+    """
+    vLLM V1 AsyncLLM catches exceptions and returns
+    only two types: EngineGenerateError and EngineDeadError.
+
+    EngineGenerateError is raised by the per-request generate()
+    method. This error could be request specific (and therefore
+    recoverable - e.g. if there is an error in input processing).
+
+    EngineDeadError is raised by the background output_handler
+    method. This error is global and therefore not recoverable.
+
+    We register these @app.exception_handlers to return nice
+    responses to the end user if they occur and shut down if needed.
+    See https://fastapi.tiangolo.com/tutorial/handling-errors/
+    for more details on how exception handlers work.
+
+    If an exception is encountered in a StreamingResponse
+    generator, the exception is not raised, since we already sent
+    a 200 status. Rather, we send an error message as the next chunk.
+    Since the exception is not raised, this means that the server
+    will not automatically shut down. Instead, we use the watchdog
+    background task to check for errored state.
+    """

     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(request: Request, __):
-        """On generic runtime error, check to see if the engine has died.
-        It probably has, in which case the server will no longer be able to
-        handle requests. Trigger a graceful shutdown with a SIGTERM."""
-        engine = request.app.state.engine_client
-        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
-                and not engine.is_running):
-            logger.fatal("AsyncLLMEngine has failed, terminating server "
-                         "process")
-            # See discussions here on shutting down a uvicorn server
-            # https://github.com/encode/uvicorn/discussions/1103
-            # In this case we cannot await the server shutdown here because
-            # this handler must first return to close the connection for
-            # this request.
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(AsyncEngineDeadError)
-    async def async_engine_dead_handler(_, __):
-        """Kill the server if the async engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("AsyncLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(MQEngineDeadError)
-    async def mq_engine_dead_handler(_, __):
-        """Kill the server if the mq engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("MQLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+    @app.exception_handler(EngineDeadError)
+    @app.exception_handler(EngineGenerateError)
+    async def runtime_exception_handler(request: Request, __):
+        terminate_if_errored(
+            server=server,
+            engine=request.app.state.engine_client,
+        )
+        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
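As a standalone illustration of the pattern used above (one coroutine registered for several exception types via stacked @app.exception_handler decorators, plus a watchdog that flips uvicorn's should_exit), here is a minimal sketch; DummyEngineDead and the duck-typed engine argument are hypothetical and not part of vLLM:

import asyncio
from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, Request, Response


class DummyEngineDead(RuntimeError):  # hypothetical stand-in exception
    pass


def register_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    # Stacked decorators register the same coroutine for every listed type.
    @app.exception_handler(RuntimeError)
    @app.exception_handler(DummyEngineDead)
    async def error_handler(request: Request, __):
        server.should_exit = True  # uvicorn exits after in-flight responses
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)


async def watchdog(server: uvicorn.Server, engine) -> None:
    # Exceptions inside a StreamingResponse generator never reach the
    # handlers (a 200 status was already sent), so poll engine state instead.
    while True:
        await asyncio.sleep(5.0)
        if engine.errored and not engine.is_running:
            server.should_exit = True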


@@ -156,3 +156,5 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     START_DP = b'\x02'
     UTILITY = b'\x03'
+    # Sentinel used within EngineCoreProc.
+    EXECUTOR_FAILED = b'\x04'


@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import logging
-import os
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Optional, Union

@@ -26,9 +24,10 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
 from vllm.v1.engine.parallel_sampling import ParentRequest

@@ -61,8 +60,6 @@ class AsyncLLM(EngineClient):
             "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
             "VLLM_USE_V1=0 or 1 and report this issue on Github.")

-        assert start_engine_loop
-
         self.model_config = vllm_config.model_config
         self.vllm_config = vllm_config
         self.log_requests = log_requests

@@ -99,15 +96,23 @@ class AsyncLLM(EngineClient):
             log_stats=self.log_stats)

         # EngineCore (starts the engine in background process).
-        self.engine_core = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
+        core_client_class = AsyncMPClient if (
+            vllm_config.parallel_config.data_parallel_size
+            == 1) else DPAsyncMPClient
+
+        self.engine_core = core_client_class(
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
         )

         self.output_handler: Optional[asyncio.Task] = None
+        try:
+            # Start output handler eagerly if we are in the asyncio eventloop.
+            asyncio.get_running_loop()
+            self._run_output_handler()
+        except RuntimeError:
+            pass

     @classmethod
     def from_vllm_config(

@@ -165,6 +170,9 @@ class AsyncLLM(EngineClient):
             usage_context=usage_context,
         )

+    def __del__(self):
+        self.shutdown()
+
     def shutdown(self):
         """Shutdown, cleaning up the background proc and IPC."""

@@ -187,6 +195,9 @@ class AsyncLLM(EngineClient):
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

+        if self.errored:
+            raise EngineDeadError()
+
         assert isinstance(params, SamplingParams), \
             "Pooling is not supported in V1"

@@ -261,9 +272,7 @@ class AsyncLLM(EngineClient):
             # We start the output_handler on the first call to generate() so
             # we can call __init__ before the event loop, which enables us
             # to handle startup failure gracefully in the OpenAI server.
-            if self.output_handler is None:
-                self.output_handler = asyncio.create_task(
-                    self._run_output_handler())
+            self._run_output_handler()

             q = await self.add_request(
                 request_id,

@@ -288,62 +297,96 @@ class AsyncLLM(EngineClient):
                 finished = out.finished
                 yield out

-        # If the request is disconnected by the client, the
-        # generate() task will be canceled. So, we abort the
-        # request if we end up here.
+        # If the request is disconnected by the client, generate()
+        # is cancelled. So, we abort the request if we end up here.
         except asyncio.CancelledError:
             await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
             raise

-    async def _run_output_handler(self):
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed (engine dead).", request_id)
+            raise
+
+        # Request validation error.
+        except ValueError:
+            if self.log_requests:
+                logger.info("Request %s failed (bad request).", request_id)
+            raise
+
+        # Unexpected error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
+
+    def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

-        try:
-            while True:
-                # 1) Pull EngineCoreOutputs from the EngineCore.
-                outputs = await self.engine_core.get_output_async()
-                num_outputs = len(outputs.outputs)
-
-                iteration_stats = IterationStats() if (
-                    self.log_stats and num_outputs) else None
-
-                # Split outputs into chunks of at most
-                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
-                # event loop for too long.
-                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
-                    slices = (outputs.outputs, )
-                else:
-                    slices = np.array_split(
-                        outputs.outputs,
-                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
-
-                for i, outputs_slice in enumerate(slices):
-                    # 2) Process EngineCoreOutputs.
-                    processed_outputs = self.output_processor.process_outputs(
-                        outputs_slice, outputs.timestamp, iteration_stats)
-                    # NOTE: RequestOutputs are pushed to their queues.
-                    assert not processed_outputs.request_outputs
-
-                    # Allow other asyncio tasks to run between chunks
-                    if i + 1 < len(slices):
-                        await asyncio.sleep(0)
-
-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
-
-                # 4) Logging.
-                # TODO(rob): make into a coroutine and launch it in
-                # background thread once Prometheus overhead is non-trivial.
-                self._record_stats(
-                    engine_index=outputs.engine_index,
-                    scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=iteration_stats,
-                )
-        except Exception as e:
-            logger.exception("EngineCore output handler hit an error: %s", e)
-            kill_process_tree(os.getpid())
+        if self.output_handler is not None:
+            return
+
+        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
+        # object, or else it won't be garbage collected and cleaned up properly.
+        engine_core = self.engine_core
+        output_processor = self.output_processor
+        log_stats = self.log_stats
+        stat_loggers = self.stat_loggers if log_stats else None
+
+        async def output_handler():
+            try:
+                while True:
+                    # 1) Pull EngineCoreOutputs from the EngineCore.
+                    outputs = await engine_core.get_output_async()
+                    num_outputs = len(outputs.outputs)
+
+                    iteration_stats = IterationStats() if (
+                        log_stats and num_outputs) else None
+
+                    # Split outputs into chunks of at most
+                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                    # event loop for too long.
+                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                        slices = (outputs.outputs, )
+                    else:
+                        slices = np.array_split(
+                            outputs.outputs,
+                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                    for i, outputs_slice in enumerate(slices):
+                        # 2) Process EngineCoreOutputs.
+                        processed_outputs = output_processor.process_outputs(
+                            outputs_slice, outputs.timestamp, iteration_stats)
+                        # NOTE: RequestOutputs are pushed to their queues.
+                        assert not processed_outputs.request_outputs
+
+                        # Allow other asyncio tasks to run between chunks
+                        if i + 1 < len(slices):
+                            await asyncio.sleep(0)
+
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)
+
+                    # 4) Logging.
+                    # TODO(rob): make into a coroutine and launch it in
+                    # background thread once Prometheus overhead is non-trivial.
+                    if stat_loggers:
+                        assert outputs.scheduler_stats is not None
+                        AsyncLLM._record_stats(
+                            stat_loggers[outputs.engine_index],
+                            scheduler_stats=outputs.scheduler_stats,
+                            iteration_stats=iteration_stats,
+                        )
+            except Exception as e:
+                logger.exception("AsyncLLM output_handler failed.")
+                output_processor.propagate_error(e)
+
+        self.output_handler = asyncio.create_task(output_handler())

     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""

@@ -354,17 +397,15 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Aborted request %s.", request_id)

+    @staticmethod
     def _record_stats(
-        self,
-        scheduler_stats: Optional[SchedulerStats],
+        stat_loggers: list[StatLoggerBase],
+        scheduler_stats: SchedulerStats,
         iteration_stats: Optional[IterationStats],
-        engine_index: int = 0,
     ):
-        if not self.log_stats:
-            return
-
-        assert scheduler_stats is not None
-        for stat_logger in self.stat_loggers[engine_index]:
+        """static so that it can be used from the output_handler task
+        without a circular ref to AsyncLLM."""
+        for stat_logger in stat_loggers:
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)

@@ -451,16 +492,17 @@ class AsyncLLM(EngineClient):
     @property
     def is_running(self) -> bool:
-        return True
+        # Is None before the loop is started.
+        return self.output_handler is None or not self.output_handler.done()

     @property
     def is_stopped(self) -> bool:
-        return False
+        return self.errored

     @property
     def errored(self) -> bool:
-        return False
+        return self.engine_core.resources.engine_dead or not self.is_running

     @property
     def dead_error(self) -> BaseException:
-        return Exception()  # TODO: implement
+        return EngineDeadError()
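The refactor above deliberately binds everything the handler needs into locals before creating the task, so the task holds no strong reference back to the AsyncLLM instance. A toy sketch of why that matters for garbage collection and __del__-driven shutdown; the Worker class is hypothetical and not vLLM code:

import asyncio
from typing import Optional


class Worker:
    def __init__(self, source: asyncio.Queue):
        self.source = source
        self.task: Optional[asyncio.Task] = None

    def start(self) -> None:
        # Capture only what the loop needs; holding `self` in the closure
        # would keep the Worker alive for as long as the task runs, so
        # __del__ (and the shutdown it triggers) would never fire.
        source = self.source

        async def loop() -> None:
            while True:
                item = await source.get()
                print("processed:", item)

        self.task = asyncio.create_task(loop())

    def __del__(self):
        if self.task is not None:
            self.task.cancel()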


@@ -11,9 +11,7 @@ from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union

 import msgspec
-import psutil
 import zmq
-import zmq.asyncio

 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group

@@ -22,8 +20,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
-                        zmq_socket_ctx)
+from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface

@@ -50,12 +47,11 @@ _R = TypeVar('_R')  # Return type for collective_rpc
 class EngineCore:
     """Inner loop of vLLM's Engine."""

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        executor_class: type[Executor],
-        log_stats: bool,
-    ):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 executor_class: type[Executor],
+                 log_stats: bool,
+                 executor_fail_callback: Optional[Callable] = None):
         assert vllm_config.model_config.runner_type != "pooling"

         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",

@@ -65,6 +61,9 @@ class EngineCore:
         # Setup Model.
         self.model_executor = executor_class(vllm_config)
+        if executor_fail_callback is not None:
+            self.model_executor.register_failure_callback(
+                executor_fail_callback)

         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks, kv_cache_config = \

@@ -254,7 +253,8 @@ class EngineCore:
         return engine_core_outputs

     def shutdown(self):
-        self.model_executor.shutdown()
+        if self.model_executor:
+            self.model_executor.shutdown()

     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)

@@ -308,6 +308,8 @@ class EngineCore:
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""

+    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
+
     def __init__(
         self,
         input_path: str,

@@ -317,11 +319,16 @@ class EngineCoreProc(EngineCore):
         log_stats: bool,
         engine_index: int = 0,
     ):
-        super().__init__(vllm_config, executor_class, log_stats)
+        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+
+        executor_fail_callback = lambda: input_queue.put_nowait(
+            (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+        super().__init__(vllm_config, executor_class, log_stats,
+                         executor_fail_callback)

         self.step_fn = (self.step if self.batch_queue is None else
                         self.step_with_batch_queue)
         self.global_unfinished_reqs = False

         # Background Threads and Queues for IO. These enable us to

@@ -329,15 +336,16 @@ class EngineCoreProc(EngineCore):
         # and to overlap some serialization/deserialization with the
         # model forward pass.
         # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
-                                            Any]] = queue.Queue()
-        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
+        self.input_queue = input_queue
+        self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, engine_index),
                          daemon=True).start()
-        threading.Thread(target=self.process_output_socket,
-                         args=(output_path, engine_index),
-                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_socket,
+            args=(output_path, engine_index),
+            daemon=True)
+        self.output_thread.start()

     @staticmethod
     def run_engine_core(*args,

@@ -364,7 +372,6 @@ class EngineCoreProc(EngineCore):
         signal.signal(signal.SIGTERM, signal_handler)
         signal.signal(signal.SIGINT, signal_handler)

-        parent_process = psutil.Process().parent()
         engine_core: Optional[EngineCoreProc] = None
         try:
             parallel_config: ParallelConfig = kwargs[

@@ -380,13 +387,15 @@ class EngineCoreProc(EngineCore):
             engine_core.run_busy_loop()

         except SystemExit:
-            logger.debug("EngineCore interrupted.")
-        except Exception:
-            traceback = get_exception_traceback()
-            logger.error("EngineCore hit an exception: %s", traceback)
-            parent_process.send_signal(signal.SIGUSR1)
+            logger.debug("EngineCore exiting.")

+        except Exception as e:
+            if engine_core is None:
+                logger.exception("EngineCore failed to start.")
+            else:
+                logger.exception("EngineCore encountered a fatal error.")
+                engine_core._send_engine_dead()
+            raise e
         finally:
             if engine_core is not None:
                 engine_core.shutdown()

@@ -458,6 +467,11 @@ class EngineCoreProc(EngineCore):
                                f" failed: {str(e)}")
                 self.output_queue.put_nowait(
                     EngineCoreOutputs(utility_output=output))
+        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
+            raise RuntimeError("Executor failed.")
+        else:
+            logger.error("Unrecognized input request type encountered: %s",
+                         request_type)

     @staticmethod
     def _convert_msgspec_args(method, args):

@@ -473,6 +487,18 @@ class EngineCoreProc(EngineCore):
                      and not isinstance(v, p.annotation) else v
                      for v, p in zip(args, arg_types))

+    def _send_engine_dead(self):
+        """Send EngineDead status to the EngineCoreClient."""
+
+        # Put ENGINE_CORE_DEAD in the queue.
+        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)
+
+        # Wait until msg sent by the daemon before shutdown.
+        self.output_thread.join(timeout=5.0)
+        if self.output_thread.is_alive():
+            logger.fatal("vLLM shutdown signal from EngineCore failed "
+                         "to send. Please report this issue.")
+
     def process_input_socket(self, input_path: str, engine_index: int):
         """Input socket IO thread."""

@@ -511,9 +537,16 @@ class EngineCoreProc(EngineCore):
         # Reuse send buffer.
         buffer = bytearray()

-        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
+        # We must set linger to ensure the ENGINE_CORE_DEAD
+        # message is sent prior to closing the socket.
+        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
+                            linger=4000) as socket:
             while True:
                 outputs = self.output_queue.get()
+                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
+                    socket.send(outputs, copy=False)
+                    break
+                assert not isinstance(outputs, bytes)
                 outputs.engine_index = engine_index
                 buffers = encoder.encode_into(outputs, buffer)
                 socket.send_multipart(buffers, copy=False)
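A minimal, self-contained sketch of the dead-sentinel handshake used above: the engine pushes a raw single-frame byte sentinel on the same socket as normal multi-frame messages, and a non-zero linger gives ZMQ time to flush it before the process exits. The socket address and helper names here are illustrative only, not vLLM's:

import zmq

ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'          # sentinel, as in EngineCoreProc
ADDR = "ipc:///tmp/example-engine-outputs"      # illustrative path


def producer(ctx: zmq.Context) -> None:
    # linger keeps the endpoint alive long enough after close() to flush
    # the sentinel, mirroring zmq_socket_ctx(..., linger=4000).
    sock = ctx.socket(zmq.PUSH)
    sock.setsockopt(zmq.LINGER, 4000)
    sock.bind(ADDR)
    sock.send_multipart([b"normal", b"payload"])
    sock.send(ENGINE_CORE_DEAD)                 # single frame signals death
    sock.close()


def consumer(ctx: zmq.Context) -> None:
    sock = ctx.socket(zmq.PULL)
    sock.connect(ADDR)
    while True:
        frames = sock.recv_multipart()
        if len(frames) == 1 and frames[0] == ENGINE_CORE_DEAD:
            # The real client raises EngineDeadError here (validate_alive).
            raise RuntimeError("engine core died")
        print("got", len(frames), "frames")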


@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os
import queue import queue
import signal
import threading
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Thread from threading import Thread
@ -21,10 +17,11 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
kill_process_tree, make_zmq_socket) make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle from vllm.v1.utils import BackgroundProcHandle
@ -305,14 +302,22 @@ class BackgroundResources:
core_engines: list[CoreEngine] = field(default_factory=list) core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
output_queue_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
# Set if any of the engines are dead. Here so that the output
# processing threads can access it without holding a ref to the client.
engine_dead: bool = False
def __call__(self): def __call__(self):
"""Clean up background resources.""" """Clean up background resources."""
for core_engine in self.core_engines: for core_engine in self.core_engines:
core_engine.close() core_engine.close()
if self.output_queue_task is not None:
self.output_queue_task.cancel()
# ZMQ context termination can hang if the sockets # ZMQ context termination can hang if the sockets
# aren't explicitly closed first. # aren't explicitly closed first.
if self.output_socket is not None: if self.output_socket is not None:
@ -327,6 +332,12 @@ class BackgroundResources:
# Send shutdown signal. # Send shutdown signal.
shutdown_sender.send(b'') shutdown_sender.send(b'')
def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer
== EngineCoreProc.ENGINE_CORE_DEAD):
self.engine_dead = True
raise EngineDeadError()
class MPClient(EngineCoreClient): class MPClient(EngineCoreClient):
""" """
@ -348,27 +359,6 @@ class MPClient(EngineCoreClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
): ):
# The child processes will send SIGUSR1 when unrecoverable
# errors happen. We kill the process tree here so that the
# stack trace is very evident.
# TODO(rob): rather than killing the main process, we should
# figure out how to raise an AsyncEngineDeadError and
# handle at the API server level so we can return a better
# error code to the clients calling vLLM.
def sigusr1_handler(signum, frame):
logger.fatal("Got fatal signal from worker processes, shutting "
"down. See stack trace above for root cause issue.")
kill_process_tree(os.getpid())
if threading.current_thread() == threading.main_thread():
signal.signal(signal.SIGUSR1, sigusr1_handler)
else:
logger.warning("SIGUSR1 handler not installed because we are not "
"running in the main thread. In this case the "
"forked engine process may not be killed when "
"an exception is raised, and you need to handle "
"the engine process shutdown manually.")
# Serialization setup. # Serialization setup.
self.encoder = MsgpackEncoder() self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs) self.decoder = MsgpackDecoder(EngineCoreOutputs)
@ -378,32 +368,37 @@ class MPClient(EngineCoreClient):
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed # This will ensure resources created so far are closed
# when the client is garbage collected, even if an # when the client is garbage collected, even if an
# exception is raised mid-construction. # exception is raised mid-construction.
self.resources = BackgroundResources(ctx=sync_ctx) self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources) self._finalizer = weakref.finalize(self, self.resources)
success = False
try:
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
# Paths and sockets for IPC. new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
self.output_path = get_open_zmq_ipc_path() vllm_config, executor_class, log_stats, input_path, self.
input_path = get_open_zmq_ipc_path() output_path, index, local_dp_rank)
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
new_core_engine = lambda index, local_dp_rank=None: CoreEngine( # Start engine core process(es).
vllm_config, executor_class, log_stats, input_path, self. self._init_core_engines(vllm_config, new_core_engine,
output_path, index, local_dp_rank) self.resources.core_engines)
# Start engine core process(es). # Wait for engine core process(es) to start.
self._init_core_engines(vllm_config, new_core_engine, self._wait_for_engine_startup()
self.resources.core_engines)
# Wait for engine core process(es) to start. self.utility_results: dict[int, AnyFuture] = {}
self._wait_for_engine_startup() success = True
finally:
self.utility_results: dict[int, AnyFuture] = {} if not success:
self._finalizer()
def _wait_for_engine_startup(self): def _wait_for_engine_startup(self):
# Get a sync handle to the socket which can be sync or async. # Get a sync handle to the socket which can be sync or async.
@ -451,8 +446,18 @@ class MPClient(EngineCoreClient):
self.core_engine = core_engine self.core_engine = core_engine
def shutdown(self): def shutdown(self):
# Terminate background resources.
self._finalizer() self._finalizer()
def _format_exception(self, e: Exception) -> Exception:
"""If errored, use EngineDeadError so root cause is clear."""
return EngineDeadError(
suppress_context=True) if self.resources.engine_dead else e
def ensure_alive(self):
if self.resources.engine_dead:
raise EngineDeadError()
def _process_utility_output(output: UtilityOutput, def _process_utility_output(output: UtilityOutput,
utility_results: dict[int, AnyFuture]): utility_results: dict[int, AnyFuture]):
@ -476,7 +481,7 @@ class SyncMPClient(MPClient):
log_stats=log_stats, log_stats=log_stats,
) )
self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
# Ensure that the outputs socket processing thread does not have # Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc. # a ref to the client which prevents gc.
@ -487,7 +492,8 @@ class SyncMPClient(MPClient):
outputs_queue = self.outputs_queue outputs_queue = self.outputs_queue
shutdown_path = get_open_zmq_inproc_path() shutdown_path = get_open_zmq_inproc_path()
self.resources.shutdown_path = shutdown_path resources = self.resources
resources.shutdown_path = shutdown_path
def process_outputs_socket(): def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR) shutdown_socket = ctx.socket(zmq.PAIR)
@ -506,12 +512,15 @@ class SyncMPClient(MPClient):
break break
frames = out_socket.recv_multipart(copy=False) frames = out_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs = decoder.decode(frames) outputs = decoder.decode(frames)
if outputs.utility_output: if outputs.utility_output:
_process_utility_output(outputs.utility_output, _process_utility_output(outputs.utility_output,
utility_results) utility_results)
else: else:
outputs_queue.put_nowait(outputs) outputs_queue.put_nowait(outputs)
except Exception as e:
outputs_queue.put_nowait(e)
finally: finally:
# Close sockets. # Close sockets.
shutdown_socket.close(linger=0) shutdown_socket.close(linger=0)
@ -524,9 +533,16 @@ class SyncMPClient(MPClient):
self.output_queue_thread.start() self.output_queue_thread.start()
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get() # If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server.
outputs = self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
return outputs
def _send_input(self, request_type: EngineCoreRequestType, request: Any): def _send_input(self, request_type: EngineCoreRequestType, request: Any):
self.ensure_alive()
# (Identity, RequestType, SerializedRequest) # (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value, msg = (self.core_engine.identity, request_type.value,
*self.encoder.encode(request)) *self.encoder.encode(request))
@ -608,61 +624,81 @@ class AsyncMPClient(MPClient):
log_stats=log_stats, log_stats=log_stats,
) )
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
self.queue_task: Optional[asyncio.Task] = None Exception]]()
try:
self.outputs_handler: Optional[Callable[ # If we are running in an asyncio event loop, start the queue task.
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None # Otherwise, it will be started lazily. If it is not started here,
# we could miss EXECUTOR_FAILED messages from engine core if they
# occur prior to any requests being sent.
asyncio.get_running_loop()
self._ensure_output_queue_task()
except RuntimeError:
pass
    def _ensure_output_queue_task(self):
        if self.outputs_queue is not None:
            return

        # Perform IO in separate task to parallelize as much as possible.
        # Avoid task having direct reference back to the client.
        self.outputs_queue = asyncio.Queue()
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
        output_handler = self.outputs_handler
        _self_ref = weakref.ref(self) if output_handler else None
        output_path = self.output_path
        output_socket = make_zmq_socket(self.ctx, output_path,
                                        zmq.constants.PULL)
        self.resources.output_socket = output_socket

        async def process_outputs_socket():
            while True:
                frames = await output_socket.recv_multipart(copy=False)
                outputs: EngineCoreOutputs = decoder.decode(frames)
                if outputs.utility_output:
                    _process_utility_output(outputs.utility_output,
                                            utility_results)
                    continue

                if output_handler is not None:
                    assert _self_ref is not None
                    _self = _self_ref()
                    if not _self:
                        # Client has been garbage collected, abort.
                        return
                    await output_handler(_self, outputs)

                if outputs.outputs or outputs.scheduler_stats:
                    outputs_queue.put_nowait(outputs)

        self.queue_task = asyncio.create_task(process_outputs_socket(),
                                              name="EngineCoreOutputQueueTask")

    async def get_output_async(self) -> EngineCoreOutputs:
        self._ensure_output_queue_task()
        assert self.outputs_queue is not None
        return await self.outputs_queue.get()

    def _ensure_output_queue_task(self):
        resources = self.resources
        if resources.output_queue_task is not None:
            return

        # Perform IO in separate task to parallelize as much as possible.
        # Avoid task having direct reference back to the client.
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
        output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs],
                                          Awaitable[None]]] = getattr(
                                              self.__class__,
                                              "process_engine_outputs", None)
        _self_ref = weakref.ref(self) if output_handler else None
        output_path = self.output_path
        output_socket = make_zmq_socket(self.ctx, output_path,
                                        zmq.constants.PULL)
        resources.output_socket = output_socket

        async def process_outputs_socket():
            try:
                while True:
                    frames = await output_socket.recv_multipart(copy=False)
                    resources.validate_alive(frames)
                    outputs: EngineCoreOutputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                        continue

                    if output_handler is not None:
                        assert _self_ref is not None
                        _self = _self_ref()
                        if not _self:
                            # Client has been garbage collected, abort.
                            return
                        await output_handler(_self, outputs)

                    if outputs.outputs or outputs.scheduler_stats:
                        outputs_queue.put_nowait(outputs)
            except Exception as e:
                outputs_queue.put_nowait(e)

        resources.output_queue_task = asyncio.create_task(
            process_outputs_socket(), name="EngineCoreOutputQueueTask")

    async def get_output_async(self) -> EngineCoreOutputs:
        self._ensure_output_queue_task()
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
        # from this (run_output_handler) task to shut down the server.
        assert self.outputs_queue is not None
        outputs = await self.outputs_queue.get()
        if isinstance(outputs, Exception):
            raise self._format_exception(outputs) from None
        return outputs
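The key idea in the new get_output_async path is that the background socket-reader task forwards any exception into the same asyncio.Queue that carries normal outputs, so the consumer cannot hang waiting on a dead producer. The standalone sketch below mirrors only that pattern; producer, consume_one, and FakeError are illustrative names, not vLLM API.

import asyncio
from typing import Union


class FakeError(Exception):
    pass


async def producer(queue: asyncio.Queue) -> None:
    # Stand-in for process_outputs_socket(): push results until something
    # goes wrong, then forward the exception through the same queue so the
    # consumer is guaranteed to observe it.
    try:
        for i in range(3):
            await asyncio.sleep(0)
            queue.put_nowait(i)
        raise FakeError("socket closed")
    except Exception as e:
        queue.put_nowait(e)


async def consume_one(queue: asyncio.Queue) -> Union[int, None]:
    # Stand-in for get_output_async(): re-raise forwarded exceptions so the
    # caller shuts down instead of waiting forever.
    item = await queue.get()
    if isinstance(item, Exception):
        raise item
    return item


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(producer(queue))
    try:
        while True:
            print(await consume_one(queue))
    except FakeError as e:
        print(f"shutting down: {e}")


if __name__ == "__main__":
    asyncio.run(main())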
    def _send_input(self,
                    request_type: EngineCoreRequestType,
                    request: Any,
                    engine: Optional[CoreEngine] = None) -> Awaitable[None]:
        self.ensure_alive()

        if engine is None:
            engine = self.core_engine

@ -671,6 +707,7 @@ class AsyncMPClient(MPClient):
    def _send_input_message(self, message: tuple[bytestr, ...],
                            engine: CoreEngine) -> Awaitable[None]:
        self.ensure_alive()

        message = (engine.identity, ) + message
        return self.input_socket.send_multipart(message, copy=False)
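The ensure_alive() calls added before each send presumably act as a fail-fast guard: if the engine has already been marked dead, refuse to queue more work instead of writing to a socket nobody will read. A minimal sketch of that idea, under that assumption (TinyClient and its fields are illustrative, not vLLM API):

class EngineDeadError(Exception):
    """Illustrative stand-in for the exception class added later in this diff."""


class TinyClient:
    def __init__(self) -> None:
        self.engine_dead = False

    def ensure_alive(self) -> None:
        # Fail fast before queuing more work onto a dead engine.
        if self.engine_dead:
            raise EngineDeadError()

    def send(self, payload: bytes) -> None:
        self.ensure_alive()
        print(f"sent {len(payload)} bytes")


client = TinyClient()
client.send(b"hello")        # ok
client.engine_dead = True
try:
    client.send(b"world")    # raises EngineDeadError
except EngineDeadError:
    print("refused: engine is dead")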
@ -754,18 +791,17 @@ class DPAsyncMPClient(AsyncMPClient):
    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(vllm_config, executor_class, log_stats)

        assert len(self.core_engines) > 1

        # Control message used for triggering dp idle mode loop.
        self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                             *self.encoder.encode(None))

        self.num_engines_running = 0
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        self.num_engines_running = 0
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        super().__init__(vllm_config, executor_class, log_stats)

        # Control message used for triggering dp idle mode loop.
        self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                             *self.encoder.encode(None))

        assert len(self.core_engines) > 1
    def _init_core_engines(
        self,

View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
class EngineGenerateError(Exception):
    """Raised when an AsyncLLM.generate() call fails. Recoverable."""
    pass


class EngineDeadError(Exception):
    """Raised when the EngineCore dies. Unrecoverable."""

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause."  # noqa: E501

        super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
        # Make the stack trace clearer when used with LLMEngine by
        # silencing the irrelevant ZMQError.
        self.__suppress_context__ = suppress_context
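A quick way to see what the __suppress_context__ flag buys: when an exception is raised inside an except block, Python normally prints the exception that was being handled as implicit context ("During handling of the above exception..."); setting the flag hides that noise. A self-contained sketch of the same mechanism (ToyDeadError and poll_socket are illustrative stand-ins, not vLLM code):

import traceback


class ToyDeadError(Exception):

    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        super().__init__("EngineCore encountered an issue.", *args, **kwargs)
        # When True, the traceback omits the exception that was being
        # handled when this one was raised (the implicit __context__).
        self.__suppress_context__ = suppress_context


def poll_socket() -> None:
    raise OSError("socket closed")  # stand-in for an irrelevant ZMQError


try:
    try:
        poll_socket()
    except OSError:
        # The low-level error is not the root cause worth showing users.
        raise ToyDeadError(suppress_context=True)
except ToyDeadError:
    traceback.print_exc()  # prints only ToyDeadError, not the OSError context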

View File

@ -28,32 +28,40 @@ class RequestOutputCollector:
    def __init__(self, output_kind: RequestOutputKind):
        self.aggregate = output_kind == RequestOutputKind.DELTA
        self.output: Optional[RequestOutput] = None
        self.output: Optional[Union[RequestOutput, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput) -> None:
        if self.output is None:
            self.output = output
            self.ready.set()
        elif self.aggregate:
            # Coalesce the outputs in delta case.
            self.output.add(output)
        else:
            # Just replace latest in non-delta case.
            self.output = output

    def put(self, output: Union[RequestOutput, Exception]) -> None:
        """Non-blocking put operation."""
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, RequestOutput):
            if self.aggregate:
                # Coalesce the outputs in delta case.
                self.output.add(output)
            else:
                # Just replace latest in non-delta case.
                self.output = output

    async def get(self) -> RequestOutput:
        """Get operation blocks on put event."""
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output

    def get_nowait(self) -> Optional[RequestOutput]:
        """Non-blocking get operation."""
        output = self.output
        if output is not None:
            self.output = None
            self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output
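The collector is effectively a one-slot mailbox per request: puts never block, DELTA outputs are merged while the consumer is slow, and an Exception poisons the slot so the waiting generate() task re-raises it. The sketch below mirrors that logic on plain strings so it runs standalone; MiniCollector and its fields are illustrative, not the vLLM class.

import asyncio
from typing import Optional, Union


class MiniCollector:
    """Self-contained analogue of RequestOutputCollector for illustration."""

    def __init__(self, aggregate: bool = True) -> None:
        self.aggregate = aggregate
        self.output: Optional[Union[str, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[str, Exception]) -> None:
        # Non-blocking: exceptions always win the slot; strings coalesce.
        if self.output is None or isinstance(output, Exception):
            self.output = output
            self.ready.set()
        elif isinstance(self.output, str):
            self.output = self.output + output if self.aggregate else output

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


async def main() -> None:
    collector = MiniCollector()
    collector.put("Hel")
    collector.put("lo")            # coalesced while the consumer is slow
    print(await collector.get())   # -> "Hello"

    collector.put(RuntimeError("engine died"))
    try:
        await collector.get()
    except RuntimeError as e:
        print(f"propagated: {e}")


asyncio.run(main())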
@ -235,6 +243,13 @@ class OutputProcessor:
    def has_unfinished_requests(self) -> bool:
        return len(self.request_states) > 0

    def propagate_error(self, e: Exception):
        """Propagate error to all generate() tasks."""
        for _, state in self.request_states.items():
            assert state.queue is not None
            state.queue.put(e)

    def abort_requests(
        self,
        request_ids: Iterable[str],

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import Future
from typing import Union
from typing import Callable, Union

import torch
import torch.distributed as dist

@ -15,6 +15,8 @@ from vllm.executor.uniproc_executor import (  # noqa
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput

FailureCallback = Callable[[], None]


class Executor(ExecutorBase):
    """

@ -62,6 +64,13 @@ class Executor(ExecutorBase):
                            args=(kv_cache_configs, ))
        self.collective_rpc("compile_or_warm_up_model")

    def register_failure_callback(self, callback: FailureCallback):
        """
        Register a function to be called if the executor enters a permanent
        failed state.
        """
        pass

    def determine_available_memory(self) -> list[int]:  # in bytes
        output = self.collective_rpc("determine_available_memory")
        return output
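register_failure_callback() gives the engine a hook that an executor fires at most once if it enters a permanently failed state; judging from the MultiprocExecutor changes later in this diff, a callback registered after the failure fires immediately. A minimal sketch of that contract (ToyExecutor and _on_worker_death are illustrative names, not vLLM API):

from typing import Callable, Optional

FailureCallback = Callable[[], None]


class ToyExecutor:
    """Illustrative executor skeleton honoring the failure-callback contract."""

    def __init__(self) -> None:
        self.is_failed = False
        self.failure_callback: Optional[FailureCallback] = None

    def register_failure_callback(self, callback: FailureCallback) -> None:
        # If the failure already happened, fire immediately; otherwise store
        # the callback so the monitor can fire it exactly once later.
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback

    def _on_worker_death(self) -> None:
        self.is_failed = True
        callback, self.failure_callback = self.failure_callback, None
        if callback is not None:
            callback()


executor = ToyExecutor()
executor.register_failure_callback(lambda: print("engine notified: executor failed"))
executor._on_worker_death()  # -> engine notified: executor failed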

View File

@ -1,21 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast

import cloudpickle
import psutil
import zmq

from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,

@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import (
    _add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port)
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

@ -35,6 +38,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000

EXECUTE_MODEL_TIMEOUT_S = 30
class MultiprocExecutor(Executor):

@ -42,19 +47,9 @@ class MultiprocExecutor(Executor):
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)

        # The child processes will send SIGUSR1 when unrecoverable
        # errors happen.
        def sigusr1_handler(signum, frame):
            logger.fatal(
                "MulitprocExecutor got fatal signal from worker processes, "
                "shutting down. See stack trace above for root cause issue.")
            # Propagate error up to parent process.
            parent_process = psutil.Process().parent()
            parent_process.send_signal(signal.SIGUSR1)
            self.shutdown()

        signal.signal(signal.SIGUSR1, sigusr1_handler)

        self.is_failed = False
        self.shutdown_event = threading.Event()
        self.failure_callback: Optional[FailureCallback] = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
@ -78,28 +73,94 @@ class MultiprocExecutor(Executor):
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        self.workers: list[WorkerProcHandle] = []
        for rank in range(self.world_size):
            worker = WorkerProc.make_worker_process(self.vllm_config, rank,
                                                    rank,
                                                    distributed_init_method,
                                                    scheduler_output_handle)
            self.workers.append(worker)

        # Ensure message queues are ready. Will deadlock if re-ordered
        # Must be kept consistent with the WorkerProc
        self.rpc_broadcast_mq.wait_until_ready()
        for w in self.workers:
            w.worker_response_mq.wait_until_ready()

        # Create workers
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
            for rank in range(self.world_size):
                unready_workers.append(
                    WorkerProc.make_worker_process(
                        vllm_config=self.vllm_config,
                        local_rank=rank,
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                    ))

            # Workers must be created before wait_for_ready to avoid
            # deadlock, since worker.init_device() does a device sync.
            self.workers = WorkerProc.wait_for_ready(unready_workers)

            # Ensure message queues are ready. Will deadlock if re-ordered
            # Must be kept consistent with the WorkerProc.
            self.rpc_broadcast_mq.wait_until_ready()
            for w in self.workers:
                w.worker_response_mq.wait_until_ready()

            self.start_worker_monitor()
            success = True
        finally:
            if not success:
                # Clean up the worker procs if there was a failure.
                self._ensure_worker_termination(
                    [w.proc for w in unready_workers])
    def start_worker_monitor(self):
        workers = self.workers
        self_ref = weakref.ref(self)

        # Monitors worker process liveness. If any die unexpectedly,
        # logs an error, shuts down the executor and invokes the failure
        # callback to inform the engine.
        def monitor_workers():
            sentinels = [h.proc.sentinel for h in workers]
            died = multiprocessing.connection.wait(sentinels)
            _self = self_ref()
            if not _self or getattr(_self, 'shutting_down', False):
                return
            _self.is_failed = True
            proc_name = next(h.proc.name for h in workers
                             if h.proc.sentinel == died[0])
            logger.error(
                "Worker proc %s died unexpectedly, "
                "shutting down executor.", proc_name)
            _self.shutdown()
            callback = _self.failure_callback
            if callback is not None:
                _self.failure_callback = None
                callback()

        Thread(target=monitor_workers,
               daemon=True,
               name="MultiprocWorkerMonitor").start()

    def register_failure_callback(self, callback: FailureCallback):
        if self.is_failed:
            callback()
        else:
            self.failure_callback = callback
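multiprocessing.connection.wait() accepts process sentinels, so a single daemon thread can block until any worker exits and then trigger cleanup. A stripped-down, runnable sketch of that monitoring idea (the process target and names here are illustrative only):

import multiprocessing
import multiprocessing.connection
from threading import Thread


def crashing_worker() -> None:
    raise SystemExit(1)  # simulate a worker dying unexpectedly


def main() -> None:
    ctx = multiprocessing.get_context("spawn")
    procs = [ctx.Process(target=crashing_worker, name=f"Worker-{i}")
             for i in range(2)]
    for p in procs:
        p.start()

    def monitor() -> None:
        # Block until at least one process sentinel becomes ready.
        died = multiprocessing.connection.wait([p.sentinel for p in procs])
        victim = next(p.name for p in procs if p.sentinel == died[0])
        print(f"{victim} died; shutting everything down")
        for p in procs:
            p.terminate()

    Thread(target=monitor, daemon=True, name="WorkerMonitor").start()
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()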
    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        (output, ) = self.collective_rpc("execute_model",
                                         args=(scheduler_output, ),
                                         rank0_reply_only=True,
                                         timeout=EXECUTE_MODEL_TIMEOUT_S)
        return output
    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None) -> list[Any]:
    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = 180.0,
                       args: tuple = (),
                       kwargs: Optional[dict] = None,
                       rank0_reply_only: bool = False) -> list[Any]:
        start_time = time.monotonic()
        kwargs = kwargs or {}

        if self.is_failed:
            raise RuntimeError("Executor failed.")

        # NOTE: If the args are heterogeneous, then we pack them into a list,
        # and unpack them in the method of every worker, because every worker
        # knows their own rank.

@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
            else:
                send_method = cloudpickle.dumps(
                    method, protocol=pickle.HIGHEST_PROTOCOL)
            self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
            self.rpc_broadcast_mq.enqueue(
                (send_method, args, kwargs, rank0_reply_only))

            responses = [None] * self.world_size
            for w in self.workers:
            workers = (self.workers[0], ) if rank0_reply_only else self.workers
            responses = [None] * len(workers)
            for w in workers:
                dequeue_timeout = timeout - (time.monotonic() - start_time
                                             ) if timeout is not None else None
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout)
                status, result = w.worker_response_mq.dequeue(
                    timeout=dequeue_timeout, cancel=self.shutdown_event)

                if status != WorkerProc.ResponseStatus.SUCCESS:
                    raise RuntimeError(
                        "Worker failed with error %s, please check the"
                        " stack trace above for the root cause", result)
                    raise RuntimeError(
                        f"Worker failed with error '{result}', please check the"
                        " stack trace above for the root cause")

                responses[w.rank] = result

            return responses
        except TimeoutError as e:
            raise TimeoutError(f"RPC call to {method} timed out.") from e
        except Exception as e:
            # Re-raise any other exceptions
            raise e
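The timeout bookkeeping above charges all sequential dequeues against one shared budget rather than giving each worker the full timeout. The helper below isolates that pattern in a runnable form; collect_with_deadline and the fake workers are illustrative only.

import time
from typing import Callable, Optional


def collect_with_deadline(waiters: list[Callable[[Optional[float]], str]],
                          timeout: Optional[float]) -> list[str]:
    # Each waiter gets only the time remaining from the shared budget,
    # mirroring `timeout - (time.monotonic() - start_time)` above.
    start_time = time.monotonic()
    results = []
    for wait in waiters:
        remaining = (timeout - (time.monotonic() - start_time)
                     if timeout is not None else None)
        if remaining is not None and remaining <= 0:
            raise TimeoutError("deadline exhausted before all replies arrived")
        results.append(wait(remaining))
    return results


# Usage: two fake "workers" that each take ~0.1s to reply.
fake_workers = [lambda t: (time.sleep(0.1), "ok")[1] for _ in range(2)]
print(collect_with_deadline(fake_workers, timeout=1.0))  # ['ok', 'ok']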
    def _ensure_worker_termination(self):
    @staticmethod
    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
        """Ensure that all worker processes are terminated. Assumes workers have
        received termination requests. Waits for processing, then sends
        termination and kill signals if needed."""

@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
            return False

        # Send SIGTERM if still running
        active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
        active_procs = [proc for proc in worker_procs if proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):

@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
            for p in active_procs:
                p.kill()

        self._cleanup_sockets()

    def _cleanup_sockets(self):
        for w in self.workers:
            # Remove the zmq ipc socket file
            socket_path = w.ready_path.replace("ipc://", "")
            if os and os.path.exists(socket_path):
                os.remove(socket_path)

    def shutdown(self):
        """Properly shut down the executor and its workers"""
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            self.shutdown_event.set()

            for w in self.workers:
                w.worker_response_mq = None
            self._ensure_worker_termination()
            self._ensure_worker_termination([w.proc for w in self.workers])

        self.rpc_broadcast_mq = None
@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
        return
@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""
    proc: BaseProcess
    rank: int
    ready_pipe: Connection


@dataclass
class WorkerProcHandle:
    proc: BaseProcess
    rank: int
    ready_path: str
    worker_response_mq: MessageQueue  # The worker process writes to this MQ

    @classmethod
    def from_unready_handle(
            cls, unready_handle: UnreadyWorkerProcHandle,
            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
        return cls(
            proc=unready_handle.proc,
            rank=unready_handle.rank,
            worker_response_mq=worker_response_mq,
        )
class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

@ -203,7 +273,6 @@ class WorkerProc:
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
        ready_path: str,
    ):
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)

@ -231,18 +300,8 @@ class WorkerProc:
        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)
        worker_response_mq_handle = self.worker_response_mq.export_handle()

        # Send Readiness signal to EngineCore process.
        # Set linger here because we want to ensure the message has
        # been sent before the context is closed.
        with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
                            linger=10000) as ready_socket:
            payload = pickle.dumps(worker_response_mq_handle,
                                   protocol=pickle.HIGHEST_PROTOCOL)
            ready_socket.send_string(WorkerProc.READY_STR)
            ready_socket.send(payload)

        # Initialize device and loads weights
        self.worker.init_device()
        self.worker.load_model()
@ -253,12 +312,10 @@ class WorkerProc:
        rank: int,
        distributed_init_method: str,
        input_shm_handle,  # Receive SchedulerOutput
    ) -> WorkerProcHandle:
    ) -> UnreadyWorkerProcHandle:
        context = get_mp_context()

        # ZMQ path for worker to send ready message and shm_broadcast handle
        # back to core process.
        ready_path = get_open_zmq_ipc_path()

        # (reader, writer)
        reader, writer = context.Pipe(duplex=False)

        process_kwargs = {
            "vllm_config": vllm_config,

@ -266,24 +323,57 @@ class WorkerProc:
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_path": ready_path,
            "ready_pipe": (reader, writer),
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               name=f"VllmWorker-{rank}",
                               daemon=True)

        with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
            proc.start()

            # Wait for startup
            worker_response_mq_handle = WorkerProc.wait_for_startup(
                proc, ready_socket)

            worker_response_mq = MessageQueue.create_from_handle(
                worker_response_mq_handle, 0)

        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)

        proc.start()
        writer.close()
        return UnreadyWorkerProcHandle(proc, rank, reader)

    @staticmethod
    def wait_for_ready(
        unready_proc_handles: list[UnreadyWorkerProcHandle]
    ) -> list[WorkerProcHandle]:

        e = Exception("WorkerProc initialization failed due to "
                      "an exception in a background process. "
                      "See stack trace for root cause.")

        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
            [None] * len(unready_proc_handles))
        while pipes:
            ready = multiprocessing.connection.wait(pipes.keys())
            for pipe in ready:
                assert isinstance(pipe, Connection)
                try:
                    # Wait until the WorkerProc is ready.
                    unready_proc_handle = pipes.pop(pipe)
                    response: dict[str, Any] = pipe.recv()
                    if response["status"] != "READY":
                        raise e

                    # Extract the message queue handle.
                    worker_response_mq = MessageQueue.create_from_handle(
                        response["handle"], 0)
                    ready_proc_handles[unready_proc_handle.rank] = (
                        WorkerProcHandle.from_unready_handle(
                            unready_proc_handle, worker_response_mq))

                except EOFError:
                    e.__suppress_context__ = True
                    raise e from None

                finally:
                    # Close connection.
                    pipe.close()

        return cast(list[WorkerProcHandle], ready_proc_handles)
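The READY handshake now uses a plain multiprocessing Pipe instead of a ZMQ socket: the parent keeps the read end, the child inherits the write end, sends a one-shot status dict once initialization finishes, and an EOF on the read end means the child died before reporting. A compact standalone sketch of the same handshake, with illustrative names:

import multiprocessing
from multiprocessing.connection import Connection
from typing import Any


def child_main(ready_pipe: "tuple[Connection, Connection]") -> None:
    reader, writer = ready_pipe
    reader.close()  # the child only writes
    # ... heavy initialization would happen here ...
    writer.send({"status": "READY", "handle": {"shm_name": "demo"}})
    writer.close()


def main() -> None:
    ctx = multiprocessing.get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child_main, args=((reader, writer), ))
    proc.start()
    writer.close()  # parent only reads; EOF now means "child is gone"

    try:
        response: dict[str, Any] = reader.recv()
        if response["status"] != "READY":
            raise RuntimeError("worker reported a startup failure")
        print("worker ready, handle:", response["handle"])
    except EOFError:
        raise RuntimeError("worker died before sending READY") from None
    finally:
        reader.close()
    proc.join()


if __name__ == "__main__":
    main()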
    def shutdown(self):
        self.rpc_broadcast_mq = None
@ -312,51 +402,51 @@ class WorkerProc:
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        try:
            worker = WorkerProc(*args, **kwargs)

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()

            worker.worker_busy_loop()

        except SystemExit:
            logger.debug("Worker interrupted.")

        except Exception:
            # worker_busy_loop sends exceptions to Executor
            # for shutdown, but if there is an error in startup or an
            # error with IPC itself, we need to alert the parent.
            psutil.Process().parent().send_signal(signal.SIGUSR1)
            raise

        finally:
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()
                worker = None

        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        # tuple[Connection, Connection]
        reader, ready_writer = kwargs.pop("ready_pipe")
        try:
            reader.close()
            worker = WorkerProc(*args, **kwargs)

            # Send READY once we know everything is loaded
            ready_writer.send({
                "status":
                WorkerProc.READY_STR,
                "handle":
                worker.worker_response_mq.export_handle(),
            })

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()
            ready_writer.close()
            ready_writer = None

            worker.worker_busy_loop()

        except Exception:
            # NOTE: if an Exception arises in busy_loop, we send
            # a FAILURE message over the MQ RPC to notify the Executor,
            # which triggers system shutdown.
            # TODO(rob): handle case where the MQ itself breaks.

            if ready_writer is not None:
                logger.exception("WorkerProc failed to start.")
            else:
                logger.exception("WorkerProc failed.")

            # The parent sends a SIGTERM to all worker processes if
            # any worker dies. Set this value so we don't re-throw
            # SystemExit() to avoid zmq exceptions in __del__.
            shutdown_requested = True

        finally:
            if ready_writer is not None:
                ready_writer.close()
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()
    @staticmethod
    def wait_for_startup(
        proc: BaseProcess,
        ready_socket: zmq.Socket,
    ) -> Optional[Handle]:
        """Wait until the Worker is ready."""

        # Wait for Worker to send READY.
        while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
            logger.debug("Waiting for WorkerProc to startup.")

            if not proc.is_alive():
                raise RuntimeError("WorkerProc failed to start.")

        message = ready_socket.recv_string()
        assert message == WorkerProc.READY_STR
        handle_frame = ready_socket.recv(copy=False)
        handle = pickle.loads(handle_frame.buffer)
        return handle
    class ResponseStatus(Enum):
        SUCCESS = auto()

@ -365,7 +455,7 @@ class WorkerProc:
    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
            method, args, kwargs = self.rpc_broadcast_mq.dequeue()
            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):

@ -377,12 +467,14 @@ class WorkerProc:
                # Notes have been introduced in python 3.11
                if hasattr(e, "add_note"):
                    e.add_note(traceback.format_exc())
                logger.exception("WorkerProc hit an exception: %s", exc_info=e)
                logger.exception("WorkerProc hit an exception.")
                # exception might not be serializable, so we convert it to
                # string, only for logging purpose.
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.FAILURE, str(e)))
                if not rank0_only or self.rank == 0:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                continue

                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.SUCCESS, output))
                if not rank0_only or self.rank == 0:
                    self.worker_response_mq.enqueue(
                        (WorkerProc.ResponseStatus.SUCCESS, output))
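The rank0_only flag travels with the broadcast so that every worker still executes the call, but only rank 0 enqueues a reply; the executor then reads a single response instead of world_size of them. A toy rendering of that convention (ToyWorker and its fields are illustrative, not vLLM API):

from dataclasses import dataclass


@dataclass
class ToyWorker:
    rank: int
    replies: list  # stand-in for the per-worker response queue

    def handle(self, method: str, rank0_only: bool) -> None:
        output = f"{method}-result-from-rank{self.rank}"
        # Every rank does the work; only rank 0 replies when asked to.
        if not rank0_only or self.rank == 0:
            self.replies.append(("SUCCESS", output))


workers = [ToyWorker(rank=r, replies=[]) for r in range(4)]
for w in workers:
    w.handle("execute_model", rank0_only=True)

# Only the rank-0 queue has a reply; the executor reads just that one.
print([len(w.replies) for w in workers])  # -> [1, 0, 0, 0]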