[V1][Frontend] Improve Shutdown And Logs (#11737)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Robert Shaw 2025-04-16 22:48:34 -04:00 committed by GitHub
parent 3c776dcefb
commit 2b05b8ce69
GPG Key ID: B5690EEEBB952194
16 changed files with 1031 additions and 347 deletions

View File

@ -552,6 +552,7 @@ steps:
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"

View File

@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
send_one_request: bool) -> None:
"""Test that AsyncLLM frees GPU memory upon deletion.
AsyncLLM always uses an MP client.
Args:
model: model under test
tensor_parallel_size: degree of tensor parallelism
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Instantiate AsyncLLM; make request to complete any deferred
# initialization; then delete instance
async_llm = AsyncLLM.from_engine_args(engine_args)
if send_one_request:
async for _ in async_llm.generate(
"Hello my name is",
request_id="abc",
sampling_params=SamplingParams(
max_tokens=1, output_kind=RequestOutputKind.DELTA)):
pass
del async_llm
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
enable_multiprocessing: bool,
send_one_request: bool) -> None:
"""Test that LLM frees GPU memory upon deletion.
TODO(andy) - LLM without multiprocessing.
Args:
model: model under test
tensor_parallel_size: degree of tensor parallelism
enable_multiprocessing: enable workers in separate process(es)
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Instantiate LLM; make request to complete any deferred
# initialization; then delete instance
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
if send_one_request:
llm.generate("Hello my name is",
sampling_params=SamplingParams(max_tokens=1))
del llm
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)

View File

@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""
import asyncio
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_forward(self, *args, **kwargs):
"""Evil forward method that raise an exception after 10 calls."""
NUMBER_OF_GOOD_PASSES = 10
if not hasattr(self, "num_calls"):
self.num_calls = 0
if (self.num_calls == NUMBER_OF_GOOD_PASSES
and get_tensor_model_parallel_rank() == 0):
raise Exception("Simulated illegal memory access on Rank 0!")
self.num_calls += 1
return self.model(*args, **kwargs)
@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
model: str) -> None:
"""Test that AsyncLLM propagates a forward pass error and frees memory.
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
async_llm = AsyncLLM.from_engine_args(engine_args)
async def generate(request_id: str):
generator = async_llm.generate("Hello my name is",
request_id=request_id,
sampling_params=SamplingParams())
try:
async for _ in generator:
pass
except Exception as e:
return e
NUM_REQS = 3
tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
outputs = await asyncio.gather(*tasks)
# Every request should get an EngineDeadError.
for output in outputs:
assert isinstance(output, EngineDeadError)
# AsyncLLM should be errored.
assert async_llm.errored
# We should not be able to make another request.
with pytest.raises(EngineDeadError):
async for _ in async_llm.generate("Hello my name is",
request_id="abc",
sampling_params=SamplingParams()):
raise Exception("We should not get here.")
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=2 * 2**30,
timeout_s=60,
)
# NOTE: shutdown is normally triggered by the API server's exception
# handlers; since there is no API server in this test, we call
# shutdown() explicitly here.
async_llm.shutdown()
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
enable_multiprocessing: bool, model: str) -> None:
"""Test that LLM propagates a forward pass error and frees memory.
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
and >1 rank
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Monkeypatch an error in the model.
m.setattr(LlamaForCausalLM, "forward", evil_forward)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
with pytest.raises(
EngineDeadError if enable_multiprocessing else Exception):
llm.generate("Hello my name is Robert and I")
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)

View File

@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""
import asyncio
import pytest
from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError
MODELS = ["meta-llama/Llama-3.2-1B"]
@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_processor_error(model: str) -> None:
"""Test that AsyncLLM propagates a processor error.
Test empty tokens prompt (failure) and non-empty prompt (no failure).
AsyncLLM always uses an MP client.
"""
engine_args = AsyncEngineArgs(model=model, enforce_eager=True)
async_llm = AsyncLLM.from_engine_args(engine_args)
async def generate(request_id: str):
# [] is not allowed and will raise a ValueError in Processor.
generator = async_llm.generate(TokensPrompt([]),
request_id=request_id,
sampling_params=SamplingParams())
try:
async for _ in generator:
pass
except Exception as e:
return e
NUM_REQS = 3
tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
outputs = await asyncio.gather(*tasks)
# Every request should get an EngineGenerateError.
for output in outputs:
with pytest.raises(EngineGenerateError):
raise output
# AsyncLLM should not be errored; the processor error is per-request.
assert not async_llm.errored
# A subsequent valid request should succeed.
EXPECTED_TOKENS = 5
outputs = []
async for out in async_llm.generate(
"Hello my name is",
request_id="abc",
sampling_params=SamplingParams(
max_tokens=EXPECTED_TOKENS,
output_kind=RequestOutputKind.DELTA)):
outputs.append(out)
generated_tokens = []
for out in outputs:
generated_tokens.extend(out.outputs[0].token_ids)
assert len(generated_tokens) == EXPECTED_TOKENS
async_llm.shutdown()

View File

@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""
import pytest
from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["meta-llama/Llama-3.2-1B"]
def evil_method(self, *args, **kwargs):
"""Evil method that raises an exception."""
if get_tensor_model_parallel_rank() == 0:
raise Exception("Simulated Error in startup!")
return self.model(*args, **kwargs, intermediate_tensors=None)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(monkeypatch, model: str,
tensor_parallel_size: int,
failing_method: str) -> None:
"""Test that AsyncLLM propagates an __init__ error & frees memory.
Test profiling (forward()) and load weights failures.
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)
engine_args = AsyncEngineArgs(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Confirm we get an exception.
with pytest.raises(Exception, match="initialization failed"):
_ = AsyncLLM.from_engine_args(engine_args)
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int,
enable_multiprocessing: bool,
failing_method: str) -> None:
"""Test that LLM propagates an __init__ error and frees memory.
Test profiling (forward()) and load weights failures.
TODO(andy) - LLM without multiprocessing.
"""
if model != "meta-llama/Llama-3.2-1B":
pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
if cuda_device_count_stateless() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:
MP_VALUE = "1" if enable_multiprocessing else "0"
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)
# Monkeypatch an error in the model.
m.setattr(LlamaForCausalLM, failing_method, evil_method)
with pytest.raises(
Exception,
match="initialization failed"
if enable_multiprocessing else "Simulated Error in startup!"):
_ = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tensor_parallel_size)
# Confirm all the processes are cleaned up.
wait_for_gpu_memory_to_clear(
devices=list(range(tensor_parallel_size)),
threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
)

View File

@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
"""Shutdown test utils"""
SHUTDOWN_TEST_TIMEOUT_SEC = 120
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30

View File

@ -7,11 +7,13 @@ import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional, Tuple, Union
from threading import Event
from typing import Any, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed as dist
import zmq
from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
@ -400,7 +402,9 @@ class MessageQueue:
break
@contextmanager
def acquire_read(self, timeout: Optional[float] = None):
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
@ -430,6 +434,9 @@ class MessageQueue:
)
n_warning += 1
if cancel is not None and cancel.is_set():
raise RuntimeError("cancelled")
# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
@ -464,10 +471,12 @@ class MessageQueue:
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self, timeout: Optional[float] = None):
def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read(timeout) as buf:
with self.acquire_read(timeout, cancel) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
@ -475,15 +484,21 @@ class MessageQueue:
# see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf[1:])
if overflow:
recv = self.local_socket.recv()
obj = pickle.loads(recv)
obj = MessageQueue.recv(self.local_socket, timeout)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
obj = pickle.loads(recv)
obj = MessageQueue.recv(self.remote_socket, timeout)
else:
raise RuntimeError("Only readers can dequeue")
return obj
@staticmethod
def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
timeout_ms = None if timeout is None else int(timeout * 1000)
if not socket.poll(timeout=timeout_ms):
raise TimeoutError
recv = socket.recv(copy=False)
return pickle.loads(recv.buffer)
def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
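The diff above threads an optional cancel Event through acquire_read()/dequeue() and adds a polling recv() so that a reader blocked on a dead writer can be interrupted during shutdown. Below is a minimal sketch of the same cancellable-read idea using only standard-library primitives; consumer_loop, work_queue and poll_interval_s are illustrative names, not part of vLLM.

import queue
import threading

def consumer_loop(work_queue: "queue.Queue[object]",
                  cancel: threading.Event,
                  poll_interval_s: float = 0.1) -> None:
    """Block on the queue, but wake up periodically to honor cancellation."""
    while not cancel.is_set():
        try:
            item = work_queue.get(timeout=poll_interval_s)
        except queue.Empty:
            continue                  # nothing yet; re-check the cancel event
        print("got item:", item)
    # the real acquire_read() raises RuntimeError("cancelled") at this point

cancel_event = threading.Event()
q: "queue.Queue[object]" = queue.Queue()
reader = threading.Thread(target=consumer_loop, args=(q, cancel_event), daemon=True)
reader.start()
q.put("hello")
cancel_event.set()                    # a shutdown path can now unblock the reader
reader.join(timeout=1.0)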

View File

@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,
loop = asyncio.get_running_loop()
watchdog_task = loop.create_task(
watchdog_loop(server, app.state.engine_client))
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))
@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
watchdog_task.cancel()
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
"""
Watchdog task that runs in the background, checking
for error state in the engine. Needed to trigger shutdown
if an exception arises in a StreamingResponse() generator.
"""
VLLM_WATCHDOG_TIME_S = 5.0
while True:
await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
terminate_if_errored(server, engine)
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
"""
See discussions here on shutting down a uvicorn server
https://github.com/encode/uvicorn/discussions/1103
In this case we cannot await the server shutdown here
because handler must first return to close the connection
for this request.
"""
engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task to check for the errored state.
"""
@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError)
async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("MQLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
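Both the watchdog task and the exception handlers above funnel into terminate_if_errored(), which asks uvicorn to stop by setting server.should_exit rather than awaiting shutdown inside a request handler. Here is a self-contained sketch of that pattern; FakeServer and FakeEngine are invented stand-ins for illustration, and only the should_exit flag mirrors the real uvicorn.Server attribute.

import asyncio
from dataclasses import dataclass

@dataclass
class FakeServer:
    # Stand-in for uvicorn.Server; should_exit asks its serve loop to stop.
    should_exit: bool = False

@dataclass
class FakeEngine:
    # Stand-in for the engine client; errored flips when the core dies.
    errored: bool = False
    is_running: bool = True

def terminate_if_errored(server: FakeServer, engine: FakeEngine) -> None:
    # We cannot await a shutdown inside a handler, so just flag the server.
    if engine.errored and not engine.is_running:
        server.should_exit = True

async def watchdog_loop(server: FakeServer, engine: FakeEngine,
                        interval_s: float = 0.05) -> None:
    # Catches errors that surface in StreamingResponse generators, where no
    # exception handler runs because a 200 status was already sent.
    while not server.should_exit:
        await asyncio.sleep(interval_s)
        terminate_if_errored(server, engine)

async def main() -> None:
    server, engine = FakeServer(), FakeEngine()
    watchdog = asyncio.create_task(watchdog_loop(server, engine))
    await asyncio.sleep(0.1)
    engine.errored, engine.is_running = True, False    # simulate engine death
    await watchdog                                     # watchdog notices and exits
    assert server.should_exit

asyncio.run(main())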

View File

@ -156,3 +156,5 @@ class EngineCoreRequestType(enum.Enum):
ABORT = b'\x01'
START_DP = b'\x02'
UTILITY = b'\x03'
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04'

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import os
from collections.abc import AsyncGenerator, Mapping
from copy import copy
from typing import Optional, Union
@ -26,9 +24,10 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv, kill_process_tree
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
@ -61,8 +60,6 @@ class AsyncLLM(EngineClient):
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
assert start_engine_loop
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.log_requests = log_requests
@ -99,15 +96,23 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
core_client_class = AsyncMPClient if (
vllm_config.parallel_config.data_parallel_size
== 1) else DPAsyncMPClient
self.engine_core = core_client_class(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)
self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
asyncio.get_running_loop()
self._run_output_handler()
except RuntimeError:
pass
@classmethod
def from_vllm_config(
@ -165,6 +170,9 @@ class AsyncLLM(EngineClient):
usage_context=usage_context,
)
def __del__(self):
self.shutdown()
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
@ -187,6 +195,9 @@ class AsyncLLM(EngineClient):
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
if self.errored:
raise EngineDeadError()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
@ -261,9 +272,7 @@ class AsyncLLM(EngineClient):
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
if self.output_handler is None:
self.output_handler = asyncio.create_task(
self._run_output_handler())
self._run_output_handler()
q = await self.add_request(
request_id,
@ -288,62 +297,96 @@ class AsyncLLM(EngineClient):
finished = out.finished
yield out
# If the request is disconnected by the client, the
# generate() task will be canceled. So, we abort the
# request if we end up here.
# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
async def _run_output_handler(self):
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
def _run_output_handler(self):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
try:
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async()
num_outputs = len(outputs.outputs)
if self.output_handler is not None:
return
iteration_stats = IterationStats() if (
self.log_stats and num_outputs) else None
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
# object, or else it won't be garbage collected and cleaned up properly.
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
async def output_handler():
try:
while True:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await engine_core.get_output_async()
num_outputs = len(outputs.outputs)
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
iteration_stats = IterationStats() if (
log_stats and num_outputs) else None
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
await asyncio.sleep(0)
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, )
else:
slices = np.array_split(
outputs.outputs,
cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
# 3) Abort any reqs that finished due to stop strings.
await self.engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
for i, outputs_slice in enumerate(slices):
# 2) Process EngineCoreOutputs.
processed_outputs = output_processor.process_outputs(
outputs_slice, outputs.timestamp, iteration_stats)
# NOTE: RequestOutputs are pushed to their queues.
assert not processed_outputs.request_outputs
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
# Allow other asyncio tasks to run between chunks
if i + 1 < len(slices):
await asyncio.sleep(0)
except Exception as e:
logger.exception("EngineCore output handler hit an error: %s", e)
kill_process_tree(os.getpid())
# 3) Abort any reqs that finished due to stop strings.
await engine_core.abort_requests_async(
processed_outputs.reqs_to_abort)
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
assert outputs.scheduler_stats is not None
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
output_processor.propagate_error(e)
self.output_handler = asyncio.create_task(output_handler())
async def abort(self, request_id: str) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
@ -354,17 +397,15 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Aborted request %s.", request_id)
@staticmethod
def _record_stats(
self,
scheduler_stats: Optional[SchedulerStats],
stat_loggers: list[StatLoggerBase],
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_index: int = 0,
):
if not self.log_stats:
return
assert scheduler_stats is not None
for stat_logger in self.stat_loggers[engine_index]:
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
@ -451,16 +492,17 @@ class AsyncLLM(EngineClient):
@property
def is_running(self) -> bool:
return True
# Is None before the loop is started.
return self.output_handler is None or not self.output_handler.done()
@property
def is_stopped(self) -> bool:
return False
return self.errored
@property
def errored(self) -> bool:
return False
return self.engine_core.resources.engine_dead or not self.is_running
@property
def dead_error(self) -> BaseException:
return Exception() # TODO: implement
return EngineDeadError()
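The reworked _run_output_handler() captures only the objects it needs (engine_core, output_processor, stat loggers) in a closure instead of holding self, so the AsyncLLM instance can still be garbage collected and its __del__-driven shutdown() actually runs. A small sketch of that idiom, with purely illustrative names:

import asyncio

class Owner:
    """Holds a background task that must not keep the Owner itself alive."""

    def __init__(self) -> None:
        self.results: list[int] = []
        results = self.results            # capture only what the task needs

        async def background() -> None:
            # No reference to `self` here, so the task does not pin Owner in memory.
            while True:
                await asyncio.sleep(0.01)
                results.append(len(results))

        self.task = asyncio.create_task(background())

    def shutdown(self) -> None:
        self.task.cancel()

    def __del__(self) -> None:
        # Because nothing cycles back through the task to `self`, this runs on GC.
        self.shutdown()

async def main() -> None:
    owner = Owner()
    await asyncio.sleep(0.05)
    owner.shutdown()
    print(owner.results)

asyncio.run(main())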

View File

@ -11,9 +11,7 @@ from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import psutil
import zmq
import zmq.asyncio
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
@ -22,8 +20,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@ -50,12 +47,11 @@ _R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
):
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
@ -65,6 +61,9 @@ class EngineCore:
# Setup Model.
self.model_executor = executor_class(vllm_config)
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
@ -254,7 +253,8 @@ class EngineCore:
return engine_core_outputs
def shutdown(self):
self.model_executor.shutdown()
if self.model_executor:
self.model_executor.shutdown()
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
@ -308,6 +308,8 @@ class EngineCore:
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
def __init__(
self,
input_path: str,
@ -317,11 +319,16 @@ class EngineCoreProc(EngineCore):
log_stats: bool,
engine_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats)
input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
executor_fail_callback = lambda: input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.global_unfinished_reqs = False
# Background Threads and Queues for IO. These enable us to
@ -329,15 +336,16 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_path, engine_index),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True)
self.output_thread.start()
@staticmethod
def run_engine_core(*args,
@ -364,7 +372,6 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent()
engine_core: Optional[EngineCoreProc] = None
try:
parallel_config: ParallelConfig = kwargs[
@ -380,13 +387,15 @@ class EngineCoreProc(EngineCore):
engine_core.run_busy_loop()
except SystemExit:
logger.debug("EngineCore interrupted.")
except Exception:
traceback = get_exception_traceback()
logger.error("EngineCore hit an exception: %s", traceback)
parent_process.send_signal(signal.SIGUSR1)
logger.debug("EngineCore exiting.")
except Exception as e:
if engine_core is None:
logger.exception("EngineCore failed to start.")
else:
logger.exception("EngineCore encountered a fatal error.")
engine_core._send_engine_dead()
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
@ -458,6 +467,11 @@ class EngineCoreProc(EngineCore):
f" failed: {str(e)}")
self.output_queue.put_nowait(
EngineCoreOutputs(utility_output=output))
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.")
else:
logger.error("Unrecognized input request type encountered: %s",
request_type)
@staticmethod
def _convert_msgspec_args(method, args):
@ -473,6 +487,18 @@ class EngineCoreProc(EngineCore):
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def _send_engine_dead(self):
"""Send EngineDead status to the EngineCoreClient."""
# Put ENGINE_CORE_DEAD in the queue.
self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)
# Wait until msg sent by the daemon before shutdown.
self.output_thread.join(timeout=5.0)
if self.output_thread.is_alive():
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_socket(self, input_path: str, engine_index: int):
"""Input socket IO thread."""
@ -511,9 +537,16 @@ class EngineCoreProc(EngineCore):
# Reuse send buffer.
buffer = bytearray()
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
with zmq_socket_ctx(output_path, zmq.constants.PUSH,
linger=4000) as socket:
while True:
outputs = self.output_queue.get()
if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
socket.send(outputs, copy=False)
break
assert not isinstance(outputs, bytes)
outputs.engine_index = engine_index
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)
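_send_engine_dead() pushes the ENGINE_CORE_DEAD byte sentinel through the same output path as normal results, and the output socket is opened with a non-zero linger so the message is flushed before the socket closes. The sketch below shows the sentinel-through-the-queue idea with plain threads and no ZMQ; all names are illustrative.

import queue
import threading

ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"

def output_thread_main(out_queue: "queue.Queue[object]", send) -> None:
    """Drain the output queue; forward the dead sentinel, then stop."""
    while True:
        item = out_queue.get()
        send(item)            # with ZMQ, linger keeps this from being dropped
        if item == ENGINE_CORE_DEAD:
            break

def send_engine_dead(out_queue: "queue.Queue[object]",
                     out_thread: threading.Thread) -> None:
    out_queue.put_nowait(ENGINE_CORE_DEAD)
    out_thread.join(timeout=5.0)      # wait until the sentinel was actually sent
    if out_thread.is_alive():
        print("shutdown signal may not have been delivered")

sent: list = []
q: "queue.Queue[object]" = queue.Queue()
t = threading.Thread(target=output_thread_main, args=(q, sent.append), daemon=True)
t.start()
q.put_nowait({"request_id": "abc", "text": "hello"})
send_engine_dead(q, t)
assert sent[-1] == ENGINE_CORE_DEAD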

View File

@ -1,14 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import queue
import signal
import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
@ -21,10 +17,11 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
kill_process_tree, make_zmq_socket)
make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
@ -305,14 +302,22 @@ class BackgroundResources:
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
output_queue_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None
# Set if any of the engines are dead. Here so that the output
# processing threads can access it without holding a ref to the client.
engine_dead: bool = False
def __call__(self):
"""Clean up background resources."""
for core_engine in self.core_engines:
core_engine.close()
if self.output_queue_task is not None:
self.output_queue_task.cancel()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if self.output_socket is not None:
@ -327,6 +332,12 @@ class BackgroundResources:
# Send shutdown signal.
shutdown_sender.send(b'')
def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer
== EngineCoreProc.ENGINE_CORE_DEAD):
self.engine_dead = True
raise EngineDeadError()
class MPClient(EngineCoreClient):
"""
@ -348,27 +359,6 @@ class MPClient(EngineCoreClient):
executor_class: type[Executor],
log_stats: bool,
):
# The child processes will send SIGUSR1 when unrecoverable
# errors happen. We kill the process tree here so that the
# stack trace is very evident.
# TODO(rob): rather than killing the main process, we should
# figure out how to raise an AsyncEngineDeadError and
# handle at the API server level so we can return a better
# error code to the clients calling vLLM.
def sigusr1_handler(signum, frame):
logger.fatal("Got fatal signal from worker processes, shutting "
"down. See stack trace above for root cause issue.")
kill_process_tree(os.getpid())
if threading.current_thread() == threading.main_thread():
signal.signal(signal.SIGUSR1, sigusr1_handler)
else:
logger.warning("SIGUSR1 handler not installed because we are not "
"running in the main thread. In this case the "
"forked engine process may not be killed when "
"an exception is raised, and you need to handle "
"the engine process shutdown manually.")
# Serialization setup.
self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs)
@ -378,32 +368,37 @@ class MPClient(EngineCoreClient):
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed
# when the client is garbage collected, even if an
# when the client is garbage collected, even if an
# exception is raised mid-construction.
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)
success = False
try:
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Wait for engine core process(es) to start.
self._wait_for_engine_startup()
# Wait for engine core process(es) to start.
self._wait_for_engine_startup()
self.utility_results: dict[int, AnyFuture] = {}
self.utility_results: dict[int, AnyFuture] = {}
success = True
finally:
if not success:
self._finalizer()
def _wait_for_engine_startup(self):
# Get a sync handle to the socket which can be sync or async.
@ -451,8 +446,18 @@ class MPClient(EngineCoreClient):
self.core_engine = core_engine
def shutdown(self):
# Terminate background resources.
self._finalizer()
def _format_exception(self, e: Exception) -> Exception:
"""If errored, use EngineDeadError so root cause is clear."""
return EngineDeadError(
suppress_context=True) if self.resources.engine_dead else e
def ensure_alive(self):
if self.resources.engine_dead:
raise EngineDeadError()
def _process_utility_output(output: UtilityOutput,
utility_results: dict[int, AnyFuture]):
@ -476,7 +481,7 @@ class SyncMPClient(MPClient):
log_stats=log_stats,
)
self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
@ -487,7 +492,8 @@ class SyncMPClient(MPClient):
outputs_queue = self.outputs_queue
shutdown_path = get_open_zmq_inproc_path()
self.resources.shutdown_path = shutdown_path
resources = self.resources
resources.shutdown_path = shutdown_path
def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR)
@ -506,12 +512,15 @@ class SyncMPClient(MPClient):
break
frames = out_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
else:
outputs_queue.put_nowait(outputs)
except Exception as e:
outputs_queue.put_nowait(e)
finally:
# Close sockets.
shutdown_socket.close(linger=0)
@ -524,9 +533,16 @@ class SyncMPClient(MPClient):
self.output_queue_thread.start()
def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server.
outputs = self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
return outputs
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
self.ensure_alive()
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
*self.encoder.encode(request))
@ -608,61 +624,81 @@ class AsyncMPClient(MPClient):
log_stats=log_stats,
)
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None
self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
Exception]]()
try:
# If we are running in an asyncio event loop, start the queue task.
# Otherwise, it will be started lazily. If it is not started here,
# we could miss EXECUTOR_FAILED messages from engine core if they
# occur prior to any requests being sent.
asyncio.get_running_loop()
self._ensure_output_queue_task()
except RuntimeError:
pass
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
resources = self.resources
if resources.output_queue_task is not None:
return
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue()
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs],
Awaitable[None]]] = getattr(
self.__class__,
"process_engine_outputs", None)
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
self.resources.output_socket = output_socket
resources.output_socket = output_socket
async def process_outputs_socket():
while True:
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
continue
try:
while True:
frames = await output_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
except Exception as e:
outputs_queue.put_nowait(e)
self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask")
resources.output_queue_task = asyncio.create_task(
process_outputs_socket(), name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs:
self._ensure_output_queue_task()
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server.
assert self.outputs_queue is not None
return await self.outputs_queue.get()
outputs = await self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
return outputs
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
self.ensure_alive()
if engine is None:
engine = self.core_engine
@ -671,6 +707,7 @@ class AsyncMPClient(MPClient):
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
self.ensure_alive()
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)
@ -754,18 +791,17 @@ class DPAsyncMPClient(AsyncMPClient):
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(vllm_config, executor_class, log_stats)
assert len(self.core_engines) > 1
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
super().__init__(vllm_config, executor_class, log_stats)
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
*self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment]
assert len(self.core_engines) > 1
def _init_core_engines(
self,

View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
class EngineGenerateError(Exception):
"""Raised when a AsyncLLM.generate() fails. Recoverable."""
pass
class EngineDeadError(Exception):
"""Raised when the EngineCore dies. Unrecoverable."""
def __init__(self, *args, suppress_context: bool = False, **kwargs):
ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501
super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
# Make stack trace clearer when using with LLMEngine by
# silencing irrelevant ZMQError.
self.__suppress_context__ = suppress_context
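Setting __suppress_context__ means that when the client re-raises EngineDeadError from inside an except block (for example after a low-level ZMQError), the printed traceback omits the irrelevant transport error. A standalone illustration of what the flag does; the class below mirrors the one above, with a ConnectionError as a stand-in for the suppressed exception.

class EngineDeadError(Exception):
    def __init__(self, *args, suppress_context: bool = False, **kwargs):
        super().__init__("EngineCore encountered an issue.", *args, **kwargs)
        # Hide the "During handling of the above exception..." chain when asked.
        self.__suppress_context__ = suppress_context

def raise_from_transport_error(suppress: bool) -> None:
    try:
        raise ConnectionError("socket closed")      # stand-in for a ZMQError
    except ConnectionError:
        raise EngineDeadError(suppress_context=suppress)

try:
    raise_from_transport_error(suppress=True)
except EngineDeadError as e:
    # __context__ is still recorded, but the traceback printer skips it,
    # so users see only the engine-death message.
    print(e.__suppress_context__, type(e.__context__).__name__)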

View File

@ -28,32 +28,40 @@ class RequestOutputCollector:
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[RequestOutput] = None
self.output: Optional[Union[RequestOutput, Exception]] = None
self.ready = asyncio.Event()
def put(self, output: RequestOutput) -> None:
if self.output is None:
def put(self, output: Union[RequestOutput, Exception]) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output
elif isinstance(self.output, RequestOutput):
if self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output
async def get(self) -> RequestOutput:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
def get_nowait(self) -> Optional[RequestOutput]:
"""Non-blocking get operation."""
output = self.output
if output is not None:
self.output = None
self.ready.clear()
if isinstance(output, Exception):
raise output
return output
@ -235,6 +243,13 @@ class OutputProcessor:
def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0
def propagate_error(self, e: Exception):
"""Propagate error to all generate() tasks."""
for _, state in self.request_states.items():
assert state.queue is not None
state.queue.put(e)
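RequestOutputCollector now carries either a RequestOutput or an Exception: propagate_error() fans a fatal error out to every per-request queue, and get()/get_nowait() re-raise it inside the awaiting generate() task. A condensed sketch of that hand-off with simplified types; Collector and the string outputs are illustrative, not the vLLM classes.

import asyncio
from typing import Optional, Union

class Collector:
    """One per request: holds the latest output or a fatal error."""

    def __init__(self) -> None:
        self.item: Optional[Union[str, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, item: Union[str, Exception]) -> None:
        # An Exception overrides any pending output; otherwise keep the latest.
        if not isinstance(self.item, Exception):
            self.item = item
        self.ready.set()

    async def get(self) -> str:
        while (item := self.item) is None:
            await self.ready.wait()
        self.item = None
        self.ready.clear()
        if isinstance(item, Exception):
            raise item          # surfaces inside the awaiting generate() task
        return item

async def main() -> None:
    collector = Collector()

    async def consumer() -> None:
        try:
            while True:
                print("got:", await collector.get())
        except RuntimeError as e:
            print("consumer saw error:", e)

    task = asyncio.create_task(consumer())
    collector.put("token-1")
    await asyncio.sleep(0)
    collector.put(RuntimeError("engine dead"))   # what propagate_error() does
    await task

asyncio.run(main())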
def abort_requests(
self,
request_ids: Iterable[str],

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from concurrent.futures import Future
from typing import Union
from typing import Callable, Union
import torch
import torch.distributed as dist
@ -15,6 +15,8 @@ from vllm.executor.uniproc_executor import ( # noqa
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
FailureCallback = Callable[[], None]
class Executor(ExecutorBase):
"""
@ -62,6 +64,13 @@ class Executor(ExecutorBase):
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")
def register_failure_callback(self, callback: FailureCallback):
"""
Register a function to be called if the executor enters a permanent
failed state.
"""
pass
def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory")
return output
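register_failure_callback() lets the engine hand the executor a function to call if it enters a permanently failed state; EngineCoreProc registers a callback that merely drops an EXECUTOR_FAILED sentinel onto its own input queue, so the busy loop notices on its next iteration. A minimal sketch of that wiring, with illustrative names:

import queue
from typing import Callable, Optional

FailureCallback = Callable[[], None]
EXECUTOR_FAILED = object()          # sentinel item for the engine's input queue

class ToyExecutor:
    def __init__(self) -> None:
        self.is_failed = False
        self._failure_callback: Optional[FailureCallback] = None

    def register_failure_callback(self, callback: FailureCallback) -> None:
        # If we already failed, fire immediately; otherwise remember it.
        if self.is_failed:
            callback()
        else:
            self._failure_callback = callback

    def simulate_worker_death(self) -> None:
        self.is_failed = True
        if self._failure_callback is not None:
            cb, self._failure_callback = self._failure_callback, None
            cb()

# Engine side: the callback only enqueues a sentinel; the busy loop reacts later.
input_queue: "queue.Queue[object]" = queue.Queue()
executor = ToyExecutor()
executor.register_failure_callback(
    lambda: input_queue.put_nowait(EXECUTOR_FAILED))

executor.simulate_worker_death()
assert input_queue.get_nowait() is EXECUTOR_FAILED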

View File

@ -1,21 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from typing import Any, Callable, Optional, Union
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
import cloudpickle
import psutil
import zmq
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -35,6 +38,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
EXECUTE_MODEL_TIMEOUT_S = 30
class MultiprocExecutor(Executor):
@ -42,19 +47,9 @@ class MultiprocExecutor(Executor):
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
# The child processes will send SIGUSR1 when unrecoverable
# errors happen.
def sigusr1_handler(signum, frame):
logger.fatal(
"MulitprocExecutor got fatal signal from worker processes, "
"shutting down. See stack trace above for root cause issue.")
# Propagate error up to parent process.
parent_process = psutil.Process().parent()
parent_process.send_signal(signal.SIGUSR1)
self.shutdown()
signal.signal(signal.SIGUSR1, sigusr1_handler)
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: Optional[FailureCallback] = None
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@ -78,28 +73,94 @@ class MultiprocExecutor(Executor):
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
self.workers: list[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
for rank in range(self.world_size):
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=rank,
rank=rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
))
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
self.workers = WorkerProc.wait_for_ready(unready_workers)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
self.start_worker_monitor()
success = True
finally:
if not success:
# Clean up the worker procs if there was a failure.
self._ensure_worker_termination(
[w.proc for w in unready_workers])
def start_worker_monitor(self):
workers = self.workers
self_ref = weakref.ref(self)
# Monitors worker process liveness. If any die unexpectedly,
# logs an error, shuts down the executor and invokes the failure
# callback to inform the engine.
def monitor_workers():
sentinels = [h.proc.sentinel for h in workers]
died = multiprocessing.connection.wait(sentinels)
_self = self_ref()
if not _self or getattr(_self, 'shutting_down', False):
return
_self.is_failed = True
proc_name = next(h.proc.name for h in workers
if h.proc.sentinel == died[0])
logger.error(
"Worker proc %s died unexpectedly, "
"shutting down executor.", proc_name)
_self.shutdown()
callback = _self.failure_callback
if callback is not None:
_self.failure_callback = None
callback()
Thread(target=monitor_workers,
daemon=True,
name="MultiprocWorkerMonitor").start()
def register_failure_callback(self, callback: FailureCallback):
if self.is_failed:
callback()
else:
self.failure_callback = callback
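The worker monitor blocks on multiprocessing.connection.wait() over the worker processes' sentinel handles; when any worker exits unexpectedly it marks the executor failed, shuts it down and fires the registered failure callback. A small standalone sketch of sentinel-based liveness monitoring under those assumptions (not the vLLM worker protocol):

import multiprocessing as mp
import threading
import time
from multiprocessing.connection import wait

def worker_main(lifetime_s: float) -> None:
    time.sleep(lifetime_s)      # pretend to do work, then exit "unexpectedly"

def monitor(procs: list, on_failure) -> None:
    sentinels = [p.sentinel for p in procs]
    died = wait(sentinels)                      # blocks until some process exits
    dead = next(p for p in procs if p.sentinel == died[0])
    print(f"worker {dead.name} died unexpectedly, shutting down executor")
    on_failure()

if __name__ == "__main__":
    procs = [mp.Process(target=worker_main, args=(0.2,), name=f"worker-{i}")
             for i in range(2)]
    for p in procs:
        p.start()

    failed = threading.Event()
    threading.Thread(target=monitor, args=(procs, failed.set),
                     daemon=True, name="WorkerMonitor").start()

    failed.wait(timeout=5.0)
    assert failed.is_set()
    for p in procs:
        p.join()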
def execute_model(
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc("execute_model",
args=(scheduler_output, ),
rank0_reply_only=True,
timeout=EXECUTE_MODEL_TIMEOUT_S)
return output
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
timeout: Optional[float] = 180.0,
args: tuple = (),
kwargs: Optional[dict] = None) -> list[Any]:
kwargs: Optional[dict] = None,
rank0_reply_only: bool = False) -> list[Any]:
start_time = time.monotonic()
kwargs = kwargs or {}
if self.is_failed:
raise RuntimeError("Executor failed.")
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, rank0_reply_only))
responses = [None] * self.world_size
for w in self.workers:
workers = (self.workers[0], ) if rank0_reply_only else self.workers
responses = [None] * len(workers)
for w in workers:
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
timeout=dequeue_timeout, cancel=self.shutdown_event)
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
"Worker failed with error %s, please check the"
" stack trace above for the root cause", result)
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause")
responses[w.rank] = result
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
except Exception as e:
# Re-raise any other exceptions
raise e
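As a rough illustration of the timeout budgeting that collective_rpc() applies across workers, here is a sketch using a plain queue.Queue stand-in for the shared-memory response queues (names are assumptions for illustration):

# Sketch only: one deadline for the whole RPC; each blocking get() receives
# only the budget that remains after earlier workers have replied.
import queue
import time


def gather_responses(result_queues: list[queue.Queue],
                     timeout: float = 180.0) -> list:
    start = time.monotonic()
    responses = []
    for q in result_queues:
        remaining = max(0.0, timeout - (time.monotonic() - start))
        try:
            responses.append(q.get(timeout=remaining))
        except queue.Empty:
            raise TimeoutError("RPC call timed out.") from None
    return responses

With rank0_reply_only=True the loop above shrinks to a single queue, which is why only the rank-0 worker enqueues a response for such calls.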
def _ensure_worker_termination(self):
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
termination and kill signals if needed."""
@@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
return False
# Send SIGTERM if still running
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
active_procs = [proc for proc in worker_procs if proc.is_alive()]
for p in active_procs:
p.terminate()
if not wait_for_termination(active_procs, 4):
@@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
for p in active_procs:
p.kill()
self._cleanup_sockets()
def _cleanup_sockets(self):
for w in self.workers:
# Remove the zmq ipc socket file
socket_path = w.ready_path.replace("ipc://", "")
if os and os.path.exists(socket_path):
os.remove(socket_path)
def shutdown(self):
"""Properly shut down the executor and its workers"""
if not getattr(self, 'shutting_down', False):
self.shutting_down = True
self.shutdown_event.set()
for w in self.workers:
w.worker_response_mq = None
self._ensure_worker_termination()
self._ensure_worker_termination([w.proc for w in self.workers])
self.rpc_broadcast_mq = None
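A simplified sketch of the terminate-then-kill escalation that _ensure_worker_termination() performs; the grace periods and helper names are assumptions for illustration:

# Sketch only: give workers time to exit, then SIGTERM, then SIGKILL.
import time
from multiprocessing.process import BaseProcess


def ensure_termination(procs: list[BaseProcess],
                       grace_s: float = 5.0,
                       term_s: float = 4.0) -> None:

    def wait_for_exit(alive: list[BaseProcess],
                      deadline: float) -> list[BaseProcess]:
        while alive and time.monotonic() < deadline:
            alive = [p for p in alive if p.is_alive()]
            time.sleep(0.1)
        return [p for p in alive if p.is_alive()]

    # Assume termination was already requested; wait for voluntary exit.
    alive = wait_for_exit([p for p in procs if p.is_alive()],
                          time.monotonic() + grace_s)
    for p in alive:
        p.terminate()
    alive = wait_for_exit(alive, time.monotonic() + term_s)
    for p in alive:
        p.kill()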
@@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
return
@dataclass
class UnreadyWorkerProcHandle:
"""WorkerProcess handle before READY."""
proc: BaseProcess
rank: int
ready_pipe: Connection
@dataclass
class WorkerProcHandle:
proc: BaseProcess
rank: int
ready_path: str
worker_response_mq: MessageQueue # The worker process writes to this MQ
@classmethod
def from_unready_handle(
cls, unready_handle: UnreadyWorkerProcHandle,
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
worker_response_mq=worker_response_mq,
)
class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""
@@ -203,7 +273,6 @@ class WorkerProc:
rank: int,
distributed_init_method: str,
input_shm_handle: Handle,
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
@@ -231,18 +300,8 @@ class WorkerProc:
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
# Initialize device and loads weights
self.worker.init_device()
self.worker.load_model()
@@ -253,12 +312,10 @@ class WorkerProc:
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
) -> WorkerProcHandle:
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
# ZMQ path for worker to send ready message and shm_broadcast handle
# back to core process.
ready_path = get_open_zmq_ipc_path()
# (reader, writer)
reader, writer = context.Pipe(duplex=False)
process_kwargs = {
"vllm_config": vllm_config,
@@ -266,24 +323,57 @@ class WorkerProc:
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_path": ready_path,
"ready_pipe": (reader, writer),
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
name=f"VllmWorker-{rank}",
daemon=True)
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
proc.start()
proc.start()
writer.close()
return UnreadyWorkerProcHandle(proc, rank, reader)
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle]
) -> list[WorkerProcHandle]:
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
e = Exception("WorkerProc initialization failed due to "
"an exception in a background process. "
"See stack trace for root cause.")
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
[None] * len(unready_proc_handles))
while pipes:
ready = multiprocessing.connection.wait(pipes.keys())
for pipe in ready:
assert isinstance(pipe, Connection)
try:
# Wait until the WorkerProc is ready.
unready_proc_handle = pipes.pop(pipe)
response: dict[str, Any] = pipe.recv()
if response["status"] != "READY":
raise e
# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq))
except EOFError:
e.__suppress_context__ = True
raise e from None
finally:
# Close connection.
pipe.close()
return cast(list[WorkerProcHandle], ready_proc_handles)
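For readers unfamiliar with this handshake, a self-contained sketch of the Pipe-based readiness protocol used above; the payload contents are simplified placeholders rather than real shared-memory handles:

# Sketch only: child sends READY plus its handle; parent detects a crashed
# child via EOFError because it closed its copy of the writer end.
import multiprocessing
from multiprocessing.connection import Connection


def child_main(ready_writer: Connection) -> None:
    # ... heavyweight initialization would happen here ...
    ready_writer.send({"status": "READY", "handle": "shm-handle-placeholder"})
    ready_writer.close()


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child_main, args=(writer, ))
    proc.start()
    writer.close()  # Parent drops its writer so recv() EOFs if the child dies.
    try:
        response = reader.recv()
        assert response["status"] == "READY"
        print("worker ready, handle:", response["handle"])
    except EOFError:
        raise RuntimeError("Worker failed before signalling READY") from None
    finally:
        reader.close()
        proc.join()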
def shutdown(self):
self.rpc_broadcast_mq = None
@@ -312,51 +402,51 @@ class WorkerProc:
signal.signal(signal.SIGINT, signal_handler)
worker = None
# tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe")
try:
reader.close()
worker = WorkerProc(*args, **kwargs)
# Send READY once we know everything is loaded
ready_writer.send({
"status":
WorkerProc.READY_STR,
"handle":
worker.worker_response_mq.export_handle(),
})
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
ready_writer.close()
ready_writer = None
worker.worker_busy_loop()
except SystemExit:
logger.debug("Worker interrupted.")
except Exception:
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
raise
# NOTE: if an Exception arises in busy_loop, we send
# a FAILURE message over the MQ RPC to notify the Executor,
# which triggers system shutdown.
# TODO(rob): handle case where the MQ itself breaks.
if ready_writer is not None:
logger.exception("WorkerProc failed to start.")
else:
logger.exception("WorkerProc failed.")
# The parent sends a SIGTERM to all worker processes if
# any worker dies. Set this value so we don't re-throw
# SystemExit() to avoid zmq exceptions in __del__.
shutdown_requested = True
finally:
if ready_writer is not None:
ready_writer.close()
# Clean up once worker exits busy loop
if worker is not None:
worker.shutdown()
worker = None
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_socket: zmq.Socket,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
# Wait for Worker to send READY.
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
class ResponseStatus(Enum):
SUCCESS = auto()
@@ -365,7 +455,7 @@ class WorkerProc:
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
try:
if isinstance(method, str):
@@ -377,12 +467,14 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
logger.exception("WorkerProc hit an exception.")
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
if not rank0_only or self.rank == 0:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
if not rank0_only or self.rank == 0:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
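Finally, a toy sketch of the worker-side half of the rank0_reply_only protocol shown in the loop above, with queue.Queue standing in for the shared-memory message queues and the names chosen for illustration:

# Sketch only: every rank executes the call, but only rank 0 enqueues a
# response when the caller asked for a single reply.
from queue import Queue


def busy_loop_step(rank: int, broadcast_q: Queue, response_q: Queue) -> None:
    method, args, kwargs, rank0_only = broadcast_q.get()
    output = method(*args, **kwargs)
    if not rank0_only or rank == 0:
        response_q.put(("SUCCESS", output))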