From 2b05b8ce69fff90b7f1e3285dcd05ced2b79fd97 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Wed, 16 Apr 2025 22:48:34 -0400 Subject: [PATCH] [V1][Frontend] Improve Shutdown And Logs (#11737) Signed-off-by: rshaw@neuralmagic.com Signed-off-by: Andrew Feldman Signed-off-by: Nick Hill Co-authored-by: rshaw@neuralmagic.com Co-authored-by: Cyrus Leung Co-authored-by: Russell Bryant Co-authored-by: Andrew Feldman Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Co-authored-by: Nick Hill --- .buildkite/test-pipeline.yaml | 1 + tests/v1/shutdown/test_delete.py | 97 ++++++ tests/v1/shutdown/test_forward_error.py | 129 +++++++ tests/v1/shutdown/test_processor_error.py | 69 ++++ tests/v1/shutdown/test_startup_error.py | 97 ++++++ tests/v1/shutdown/utils.py | 5 + .../device_communicators/shm_broadcast.py | 31 +- vllm/entrypoints/launcher.py | 98 ++++-- vllm/v1/engine/__init__.py | 2 + vllm/v1/engine/async_llm.py | 176 ++++++---- vllm/v1/engine/core.py | 87 +++-- vllm/v1/engine/core_client.py | 202 ++++++----- vllm/v1/engine/exceptions.py | 16 + vllm/v1/engine/output_processor.py | 33 +- vllm/v1/executor/abstract.py | 11 +- vllm/v1/executor/multiproc_executor.py | 324 +++++++++++------- 16 files changed, 1031 insertions(+), 347 deletions(-) create mode 100644 tests/v1/shutdown/test_delete.py create mode 100644 tests/v1/shutdown/test_forward_error.py create mode 100644 tests/v1/shutdown/test_processor_error.py create mode 100644 tests/v1/shutdown/test_startup_error.py create mode 100644 tests/v1/shutdown/utils.py create mode 100644 vllm/v1/engine/exceptions.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c86f6add..5fc7b48b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -552,6 +552,7 @@ steps: # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min working_dir: "/vllm-workspace/tests" diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py new file mode 100644 index 00000000..ed368fe8 --- /dev/null +++ b/tests/v1/shutdown/test_delete.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import RequestOutputKind +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("send_one_request", [False, True]) +async def test_async_llm_delete(model: str, tensor_parallel_size: int, + send_one_request: bool) -> None: + """Test that AsyncLLM frees GPU memory upon deletion. + AsyncLLM always uses an MP client. 
+ + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Instantiate AsyncLLM; make request to complete any deferred + # initialization; then delete instance + async_llm = AsyncLLM.from_engine_args(engine_args) + if send_one_request: + async for _ in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=1, output_kind=RequestOutputKind.DELTA)): + pass + del async_llm + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("send_one_request", [False, True]) +def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + send_one_request: bool) -> None: + """Test that LLM frees GPU memory upon deletion. + TODO(andy) - LLM without multiprocessing. + + Args: + model: model under test + tensor_parallel_size: degree of tensor parallelism + enable_multiprocessing: enable workers in separate process(es) + send_one_request: send one request to engine before deleting + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Instantiate LLM; make request to complete any deferred + # initialization; then delete instance + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + if send_one_request: + llm.generate("Hello my name is", + sampling_params=SamplingParams(max_tokens=1)) + del llm + + # Confirm all the processes are cleaned up. 
+ wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py new file mode 100644 index 00000000..9fedbe4f --- /dev/null +++ b/tests/v1/shutdown/test_forward_error.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle an Error in model forward and shutdown.""" + +import asyncio + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM, AsyncEngineArgs, SamplingParams +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineDeadError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_forward(self, *args, **kwargs): + """Evil forward method that raise an exception after 10 calls.""" + NUMBER_OF_GOOD_PASSES = 10 + + if not hasattr(self, "num_calls"): + self.num_calls = 0 + + if (self.num_calls == NUMBER_OF_GOOD_PASSES + and get_tensor_model_parallel_rank() == 0): + raise Exception("Simulated illegal memory access on Rank 0!") + self.num_calls += 1 + + return self.model(*args, **kwargs) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, + model: str) -> None: + """Test that AsyncLLM propagates a forward pass error and frees memory. + + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + generator = async_llm.generate("Hello my name is", + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should get an EngineDeadError. + for output in outputs: + assert isinstance(output, EngineDeadError) + + # AsyncLLM should be errored. + assert async_llm.errored + + # We should not be able to make another request. + with pytest.raises(EngineDeadError): + async for _ in async_llm.generate("Hello my name is", + request_id="abc", + sampling_params=SamplingParams()): + raise Exception("We should not get here.") + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + + # NOTE: shutdown is handled by the API Server if an exception + # occurs, so it is expected that we would need to call this. 
+ async_llm.shutdown() + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("model", MODELS) +def test_llm_model_error(monkeypatch, tensor_parallel_size: int, + enable_multiprocessing: bool, model: str) -> None: + """Test that LLM propagates a forward pass error and frees memory. + TODO(andy) - LLM without multiprocessing; LLM with multiprocessing + and >1 rank + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. + m.setattr(LlamaForCausalLM, "forward", evil_forward) + + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + with pytest.raises( + EngineDeadError if enable_multiprocessing else Exception): + llm.generate("Hello my name is Robert and I") + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py new file mode 100644 index 00000000..0fe48da4 --- /dev/null +++ b/tests/v1/shutdown/test_processor_error.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test error handling in Processor. Should not impact other reqs.""" + +import asyncio + +import pytest + +from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs.data import TokensPrompt +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.exceptions import EngineGenerateError + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +@pytest.mark.asyncio +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +async def test_async_llm_processor_error(model: str) -> None: + """Test that AsyncLLM propagates a processor error. + Test empty tokens prompt (failure) and non-empty prompt (no failure.) + AsyncLLM always uses an MP client. + """ + engine_args = AsyncEngineArgs(model=model, enforce_eager=True) + async_llm = AsyncLLM.from_engine_args(engine_args) + + async def generate(request_id: str): + # [] is not allowed and will raise a ValueError in Processor. + generator = async_llm.generate(TokensPrompt([]), + request_id=request_id, + sampling_params=SamplingParams()) + try: + async for _ in generator: + pass + except Exception as e: + return e + + NUM_REQS = 3 + tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)] + outputs = await asyncio.gather(*tasks) + + # Every request should have get an EngineGenerateError. + for output in outputs: + with pytest.raises(EngineGenerateError): + raise output + + # AsyncLLM should be errored. + assert not async_llm.errored + + # This should be no problem. 
+ EXPECTED_TOKENS = 5 + outputs = [] + async for out in async_llm.generate( + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=EXPECTED_TOKENS, + output_kind=RequestOutputKind.DELTA)): + outputs.append(out) + + generated_tokens = [] + for out in outputs: + generated_tokens.extend(out.outputs[0].token_ids) + assert len(generated_tokens) == EXPECTED_TOKENS + + async_llm.shutdown() diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py new file mode 100644 index 00000000..1bba1910 --- /dev/null +++ b/tests/v1/shutdown/test_startup_error.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test that we handle a startup Error and shutdown.""" + +import pytest + +from tests.utils import wait_for_gpu_memory_to_clear +from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC) +from vllm import LLM +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.utils import cuda_device_count_stateless +from vllm.v1.engine.async_llm import AsyncLLM + +MODELS = ["meta-llama/Llama-3.2-1B"] + + +def evil_method(self, *args, **kwargs): + """Evil method that raises an exception.""" + + if get_tensor_model_parallel_rank() == 0: + raise Exception("Simulated Error in startup!") + + return self.model(*args, **kwargs, intermediate_tensors=None) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_async_llm_startup_error(monkeypatch, model: str, + tensor_parallel_size: int, + failing_method: str) -> None: + """Test that AsyncLLM propagates an __init__ error & frees memory. + Test profiling (forward()) and load weights failures. + AsyncLLM always uses an MP client. + """ + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + engine_args = AsyncEngineArgs(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm we get an exception. + with pytest.raises(Exception, match="initialization failed"): + _ = AsyncLLM.from_engine_args(engine_args) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) + + +@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +@pytest.mark.parametrize("enable_multiprocessing", [True]) +@pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) +def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, + enable_multiprocessing: bool, + failing_method: str) -> None: + """Test that LLM propagates an __init__ error and frees memory. + Test profiling (forward()) and load weights failures. + TODO(andy) - LLM without multiprocessing. 
+ """ + if model != "meta-llama/Llama-3.2-1B": + pytest.skip(reason="Only test meta-llama/Llama-3.2-1B") + if cuda_device_count_stateless() < tensor_parallel_size: + pytest.skip(reason="Not enough CUDA devices") + + with monkeypatch.context() as m: + + MP_VALUE = "1" if enable_multiprocessing else "0" + m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) + + # Monkeypatch an error in the model. + monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) + + with pytest.raises( + Exception, + match="initialization failed" + if enable_multiprocessing else "Simulated Error in startup!"): + _ = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size) + + # Confirm all the processes are cleaned up. + wait_for_gpu_memory_to_clear( + devices=list(range(tensor_parallel_size)), + threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES, + ) diff --git a/tests/v1/shutdown/utils.py b/tests/v1/shutdown/utils.py new file mode 100644 index 00000000..8f7c0380 --- /dev/null +++ b/tests/v1/shutdown/utils.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shutdown test utils""" + +SHUTDOWN_TEST_TIMEOUT_SEC = 120 +SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30 diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 11ed7c08..49a65bd0 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -7,11 +7,13 @@ import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import List, Optional, Tuple, Union +from threading import Event +from typing import Any, List, Optional, Tuple, Union from unittest.mock import patch import torch import torch.distributed as dist +import zmq from torch.distributed import ProcessGroup from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore @@ -400,7 +402,9 @@ class MessageQueue: break @contextmanager - def acquire_read(self, timeout: Optional[float] = None): + def acquire_read(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -430,6 +434,9 @@ class MessageQueue: ) n_warning += 1 + if cancel is not None and cancel.is_set(): + raise RuntimeError("cancelled") + # if we time out, raise an exception if (timeout is not None and time.monotonic() - start_time > timeout): @@ -464,10 +471,12 @@ class MessageQueue: if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self, timeout: Optional[float] = None): + def dequeue(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): """ Read from message queue with optional timeout (in seconds) """ if self._is_local_reader: - with self.acquire_read(timeout) as buf: + with self.acquire_read(timeout, cancel) as buf: overflow = buf[0] == 1 if not overflow: # no need to know the size of serialized object @@ -475,15 +484,21 @@ class MessageQueue: # see https://docs.python.org/3/library/pickle.html obj = pickle.loads(buf[1:]) if overflow: - recv = self.local_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.local_socket, timeout) elif self._is_remote_reader: - recv = self.remote_socket.recv() - obj = pickle.loads(recv) + obj = MessageQueue.recv(self.remote_socket, timeout) else: raise RuntimeError("Only readers can dequeue") return obj + 
@staticmethod + def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + timeout_ms = None if timeout is None else int(timeout * 1000) + if not socket.poll(timeout=timeout_ms): + raise TimeoutError + recv = socket.recv(copy=False) + return pickle.loads(recv.buffer) + def broadcast_object(self, obj=None): if self._is_writer: self.enqueue(obj) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index b09ee526..a4f70a51 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.protocol import EngineClient from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) @@ -40,6 +42,8 @@ async def serve_http(app: FastAPI, loop = asyncio.get_running_loop() + watchdog_task = loop.create_task( + watchdog_loop(server, app.state.engine_client)) server_task = loop.create_task( server.serve(sockets=[sock] if sock else None)) @@ -52,6 +56,7 @@ async def serve_http(app: FastAPI, def signal_handler() -> None: # prevents the uvicorn signal handler to exit early server_task.cancel() + watchdog_task.cancel() if ssl_cert_refresher: ssl_cert_refresher.stop() @@ -73,48 +78,69 @@ async def serve_http(app: FastAPI, port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() + finally: + watchdog_task.cancel() + + +async def watchdog_loop(server: uvicorn.Server, engine: EngineClient): + """ + # Watchdog task that runs in the background, checking + # for error state in the engine. Needed to trigger shutdown + # if an exception arises is StreamingResponse() generator. + """ + VLLM_WATCHDOG_TIME_S = 5.0 + while True: + await asyncio.sleep(VLLM_WATCHDOG_TIME_S) + terminate_if_errored(server, engine) + + +def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): + """ + See discussions here on shutting down a uvicorn server + https://github.com/encode/uvicorn/discussions/1103 + In this case we cannot await the server shutdown here + because handler must first return to close the connection + for this request. + """ + engine_errored = engine.errored and not engine.is_running + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: + server.should_exit = True def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: - """Adds handlers for fatal errors that should crash the server""" + """ + VLLM V1 AsyncLLM catches exceptions and returns + only two types: EngineGenerateError and EngineDeadError. + + EngineGenerateError is raised by the per request generate() + method. This error could be request specific (and therefore + recoverable - e.g. if there is an error in input processing). + + EngineDeadError is raised by the background output_handler + method. This error is global and therefore not recoverable. + + We register these @app.exception_handlers to return nice + responses to the end user if they occur and shut down if needed. + See https://fastapi.tiangolo.com/tutorial/handling-errors/ + for more details on how exception handlers work. + + If an exception is encountered in a StreamingResponse + generator, the exception is not raised, since we already sent + a 200 status. 
Rather, we send an error message as the next chunk. + Since the exception is not raised, this means that the server + will not automatically shut down. Instead, we use the watchdog + background task for check for errored state. + """ @app.exception_handler(RuntimeError) - async def runtime_error_handler(request: Request, __): - """On generic runtime error, check to see if the engine has died. - It probably has, in which case the server will no longer be able to - handle requests. Trigger a graceful shutdown with a SIGTERM.""" - engine = request.app.state.engine_client - if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored - and not engine.is_running): - logger.fatal("AsyncLLMEngine has failed, terminating server " - "process") - # See discussions here on shutting down a uvicorn server - # https://github.com/encode/uvicorn/discussions/1103 - # In this case we cannot await the server shutdown here because - # this handler must first return to close the connection for - # this request. - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(AsyncEngineDeadError) - async def async_engine_dead_handler(_, __): - """Kill the server if the async engine is already dead. It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("AsyncLLMEngine is already dead, terminating server " - "process") - server.should_exit = True - - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - @app.exception_handler(MQEngineDeadError) - async def mq_engine_dead_handler(_, __): - """Kill the server if the mq engine is already dead. It will - not handle any further requests.""" - if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: - logger.fatal("MQLLMEngine is already dead, terminating server " - "process") - server.should_exit = True + @app.exception_handler(EngineDeadError) + @app.exception_handler(EngineGenerateError) + async def runtime_exception_handler(request: Request, __): + terminate_if_errored( + server=server, + engine=request.app.state.engine_client, + ) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1264e43c..af4122a5 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -156,3 +156,5 @@ class EngineCoreRequestType(enum.Enum): ABORT = b'\x01' START_DP = b'\x02' UTILITY = b'\x03' + # Sentinel used within EngineCoreProc. 
+ EXECUTOR_FAILED = b'\x04' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 6d24ba2b..bc49a0d3 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio import logging -import os from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Optional, Union @@ -26,9 +24,10 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, cdiv, kill_process_tree +from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) from vllm.v1.engine.parallel_sampling import ParentRequest @@ -61,8 +60,6 @@ class AsyncLLM(EngineClient): "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - assert start_engine_loop - self.model_config = vllm_config.model_config self.vllm_config = vllm_config self.log_requests = log_requests @@ -99,15 +96,23 @@ class AsyncLLM(EngineClient): log_stats=self.log_stats) # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_client( - multiprocess_mode=True, - asyncio_mode=True, + core_client_class = AsyncMPClient if ( + vllm_config.parallel_config.data_parallel_size + == 1) else DPAsyncMPClient + + self.engine_core = core_client_class( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, ) self.output_handler: Optional[asyncio.Task] = None + try: + # Start output handler eagerly if we are in the asyncio eventloop. + asyncio.get_running_loop() + self._run_output_handler() + except RuntimeError: + pass @classmethod def from_vllm_config( @@ -165,6 +170,9 @@ class AsyncLLM(EngineClient): usage_context=usage_context, ) + def __del__(self): + self.shutdown() + def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" @@ -187,6 +195,9 @@ class AsyncLLM(EngineClient): ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" + if self.errored: + raise EngineDeadError() + assert isinstance(params, SamplingParams), \ "Pooling is not supported in V1" @@ -261,9 +272,7 @@ class AsyncLLM(EngineClient): # We start the output_handler on the first call to generate() so # we can call __init__ before the event loop, which enables us # to handle startup failure gracefully in the OpenAI server. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) + self._run_output_handler() q = await self.add_request( request_id, @@ -288,62 +297,96 @@ class AsyncLLM(EngineClient): finished = out.finished yield out - # If the request is disconnected by the client, the - # generate() task will be canceled. So, we abort the - # request if we end up here. + # If the request is disconnected by the client, generate() + # is cancelled. So, we abort the request if we end up here. 
except asyncio.CancelledError: await self.abort(request_id) + if self.log_requests: + logger.info("Request %s aborted.", request_id) raise - async def _run_output_handler(self): + # Engine is dead. Do not abort since we shut down. + except EngineDeadError: + if self.log_requests: + logger.info("Request %s failed (engine dead).", request_id) + raise + + # Request validation error. + except ValueError: + if self.log_requests: + logger.info("Request %s failed (bad request).", request_id) + raise + + # Unexpected error in the generate() task (possibly recoverable). + except Exception as e: + await self.abort(request_id) + if self.log_requests: + logger.info("Request %s failed.", request_id) + raise EngineGenerateError() from e + + def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" - try: - while True: - # 1) Pull EngineCoreOutputs from the EngineCore. - outputs = await self.engine_core.get_output_async() - num_outputs = len(outputs.outputs) + if self.output_handler is not None: + return - iteration_stats = IterationStats() if ( - self.log_stats and num_outputs) else None + # Ensure that the task doesn't have a circular ref back to the AsyncLLM + # object, or else it won't be garbage collected and cleaned up properly. + engine_core = self.engine_core + output_processor = self.output_processor + log_stats = self.log_stats + stat_loggers = self.stat_loggers if log_stats else None - # Split outputs into chunks of at most - # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the - # event loop for too long. - if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) - else: - slices = np.array_split( - outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + async def output_handler(): + try: + while True: + # 1) Pull EngineCoreOutputs from the EngineCore. + outputs = await engine_core.get_output_async() + num_outputs = len(outputs.outputs) - for i, outputs_slice in enumerate(slices): - # 2) Process EngineCoreOutputs. - processed_outputs = self.output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) - # NOTE: RequestOutputs are pushed to their queues. - assert not processed_outputs.request_outputs + iteration_stats = IterationStats() if ( + log_stats and num_outputs) else None - # Allow other asyncio tasks to run between chunks - if i + 1 < len(slices): - await asyncio.sleep(0) + # Split outputs into chunks of at most + # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the + # event loop for too long. + if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + slices = (outputs.outputs, ) + else: + slices = np.array_split( + outputs.outputs, + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) - # 3) Abort any reqs that finished due to stop strings. - await self.engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) + for i, outputs_slice in enumerate(slices): + # 2) Process EngineCoreOutputs. + processed_outputs = output_processor.process_outputs( + outputs_slice, outputs.timestamp, iteration_stats) + # NOTE: RequestOutputs are pushed to their queues. + assert not processed_outputs.request_outputs - # 4) Logging. - # TODO(rob): make into a coroutine and launch it in - # background thread once Prometheus overhead is non-trivial. 
- self._record_stats( - engine_index=outputs.engine_index, - scheduler_stats=outputs.scheduler_stats, - iteration_stats=iteration_stats, - ) + # Allow other asyncio tasks to run between chunks + if i + 1 < len(slices): + await asyncio.sleep(0) - except Exception as e: - logger.exception("EngineCore output handler hit an error: %s", e) - kill_process_tree(os.getpid()) + # 3) Abort any reqs that finished due to stop strings. + await engine_core.abort_requests_async( + processed_outputs.reqs_to_abort) + + # 4) Logging. + # TODO(rob): make into a coroutine and launch it in + # background thread once Prometheus overhead is non-trivial. + if stat_loggers: + assert outputs.scheduler_stats is not None + AsyncLLM._record_stats( + stat_loggers[outputs.engine_index], + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + ) + except Exception as e: + logger.exception("AsyncLLM output_handler failed.") + output_processor.propagate_error(e) + + self.output_handler = asyncio.create_task(output_handler()) async def abort(self, request_id: str) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" @@ -354,17 +397,15 @@ class AsyncLLM(EngineClient): if self.log_requests: logger.info("Aborted request %s.", request_id) + @staticmethod def _record_stats( - self, - scheduler_stats: Optional[SchedulerStats], + stat_loggers: list[StatLoggerBase], + scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats], - engine_index: int = 0, ): - if not self.log_stats: - return - - assert scheduler_stats is not None - for stat_logger in self.stat_loggers[engine_index]: + """static so that it can be used from the output_handler task + without a circular ref to AsyncLLM.""" + for stat_logger in stat_loggers: stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) @@ -451,16 +492,17 @@ class AsyncLLM(EngineClient): @property def is_running(self) -> bool: - return True + # Is None before the loop is started. 
+ return self.output_handler is None or not self.output_handler.done() @property def is_stopped(self) -> bool: - return False + return self.errored @property def errored(self) -> bool: - return False + return self.engine_core.resources.engine_dead or not self.is_running @property def dead_error(self) -> BaseException: - return Exception() # TODO: implement + return EngineDeadError() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f642e510..ba5e5050 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,9 +11,7 @@ from logging import DEBUG from typing import Any, Callable, Optional, TypeVar, Union import msgspec -import psutil import zmq -import zmq.asyncio from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group @@ -22,8 +20,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, - zmq_socket_ctx) +from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -50,12 +47,11 @@ _R = TypeVar('_R') # Return type for collective_rpc class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - ): + def __init__(self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" logger.info("Initializing a V1 LLM engine (v%s) with config: %s", @@ -65,6 +61,9 @@ class EngineCore: # Setup Model. self.model_executor = executor_class(vllm_config) + if executor_fail_callback is not None: + self.model_executor.register_failure_callback( + executor_fail_callback) # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ @@ -254,7 +253,8 @@ class EngineCore: return engine_core_outputs def shutdown(self): - self.model_executor.shutdown() + if self.model_executor: + self.model_executor.shutdown() def profile(self, is_start: bool = True): self.model_executor.profile(is_start) @@ -308,6 +308,8 @@ class EngineCore: class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" + ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + def __init__( self, input_path: str, @@ -317,11 +319,16 @@ class EngineCoreProc(EngineCore): log_stats: bool, engine_index: int = 0, ): - super().__init__(vllm_config, executor_class, log_stats) + input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() + + executor_fail_callback = lambda: input_queue.put_nowait( + (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + + super().__init__(vllm_config, executor_class, log_stats, + executor_fail_callback) self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) - self.global_unfinished_reqs = False # Background Threads and Queues for IO. These enable us to @@ -329,15 +336,16 @@ class EngineCoreProc(EngineCore): # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- self.input_queue: queue.Queue[tuple[EngineCoreRequestType, - Any]] = queue.Queue() - self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() threading.Thread(target=self.process_input_socket, args=(input_path, engine_index), daemon=True).start() - threading.Thread(target=self.process_output_socket, - args=(output_path, engine_index), - daemon=True).start() + self.output_thread = threading.Thread( + target=self.process_output_socket, + args=(output_path, engine_index), + daemon=True) + self.output_thread.start() @staticmethod def run_engine_core(*args, @@ -364,7 +372,6 @@ class EngineCoreProc(EngineCore): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - parent_process = psutil.Process().parent() engine_core: Optional[EngineCoreProc] = None try: parallel_config: ParallelConfig = kwargs[ @@ -380,13 +387,15 @@ class EngineCoreProc(EngineCore): engine_core.run_busy_loop() except SystemExit: - logger.debug("EngineCore interrupted.") - - except Exception: - traceback = get_exception_traceback() - logger.error("EngineCore hit an exception: %s", traceback) - parent_process.send_signal(signal.SIGUSR1) + logger.debug("EngineCore exiting.") + except Exception as e: + if engine_core is None: + logger.exception("EngineCore failed to start.") + else: + logger.exception("EngineCore encountered a fatal error.") + engine_core._send_engine_dead() + raise e finally: if engine_core is not None: engine_core.shutdown() @@ -458,6 +467,11 @@ class EngineCoreProc(EngineCore): f" failed: {str(e)}") self.output_queue.put_nowait( EngineCoreOutputs(utility_output=output)) + elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: + raise RuntimeError("Executor failed.") + else: + logger.error("Unrecognized input request type encountered: %s", + request_type) @staticmethod def _convert_msgspec_args(method, args): @@ -473,6 +487,18 @@ class EngineCoreProc(EngineCore): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) + def _send_engine_dead(self): + """Send EngineDead status to the EngineCoreClient.""" + + # Put ENGINE_CORE_DEAD in the queue. + self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD) + + # Wait until msg sent by the daemon before shutdown. + self.output_thread.join(timeout=5.0) + if self.output_thread.is_alive(): + logger.fatal("vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue.") + def process_input_socket(self, input_path: str, engine_index: int): """Input socket IO thread.""" @@ -511,9 +537,16 @@ class EngineCoreProc(EngineCore): # Reuse send buffer. buffer = bytearray() - with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: + # We must set linger to ensure the ENGINE_CORE_DEAD + # message is sent prior to closing the socket. 
+ with zmq_socket_ctx(output_path, zmq.constants.PUSH, + linger=4000) as socket: while True: outputs = self.output_queue.get() + if outputs == EngineCoreProc.ENGINE_CORE_DEAD: + socket.send(outputs, copy=False) + break + assert not isinstance(outputs, bytes) outputs.engine_index = engine_index buffers = encoder.encode_into(outputs, buffer) socket.send_multipart(buffers, copy=False) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a96ebc7e..f54b3546 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,14 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 - import asyncio -import os import queue -import signal -import threading import uuid import weakref from abc import ABC, abstractmethod -from collections.abc import Awaitable +from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread @@ -21,10 +17,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - kill_process_tree, make_zmq_socket) + make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc +from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.utils import BackgroundProcHandle @@ -305,14 +302,22 @@ class BackgroundResources: core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + output_queue_task: Optional[asyncio.Task] = None shutdown_path: Optional[str] = None + # Set if any of the engines are dead. Here so that the output + # processing threads can access it without holding a ref to the client. + engine_dead: bool = False + def __call__(self): """Clean up background resources.""" for core_engine in self.core_engines: core_engine.close() + if self.output_queue_task is not None: + self.output_queue_task.cancel() + # ZMQ context termination can hang if the sockets # aren't explicitly closed first. if self.output_socket is not None: @@ -327,6 +332,12 @@ class BackgroundResources: # Send shutdown signal. shutdown_sender.send(b'') + def validate_alive(self, frames: Sequence[zmq.Frame]): + if len(frames) == 1 and (frames[0].buffer + == EngineCoreProc.ENGINE_CORE_DEAD): + self.engine_dead = True + raise EngineDeadError() + class MPClient(EngineCoreClient): """ @@ -348,27 +359,6 @@ class MPClient(EngineCoreClient): executor_class: type[Executor], log_stats: bool, ): - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. We kill the process tree here so that the - # stack trace is very evident. - # TODO(rob): rather than killing the main process, we should - # figure out how to raise an AsyncEngineDeadError and - # handle at the API server level so we can return a better - # error code to the clients calling vLLM. - def sigusr1_handler(signum, frame): - logger.fatal("Got fatal signal from worker processes, shutting " - "down. 
See stack trace above for root cause issue.") - kill_process_tree(os.getpid()) - - if threading.current_thread() == threading.main_thread(): - signal.signal(signal.SIGUSR1, sigusr1_handler) - else: - logger.warning("SIGUSR1 handler not installed because we are not " - "running in the main thread. In this case the " - "forked engine process may not be killed when " - "an exception is raised, and you need to handle " - "the engine process shutdown manually.") - # Serialization setup. self.encoder = MsgpackEncoder() self.decoder = MsgpackDecoder(EngineCoreOutputs) @@ -378,32 +368,37 @@ class MPClient(EngineCoreClient): self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx # This will ensure resources created so far are closed - # when the client is garbage collected, even if an + # when the client is garbage collected, even if an # exception is raised mid-construction. self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) + success = False + try: + # Paths and sockets for IPC. + self.output_path = get_open_zmq_ipc_path() + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(self.ctx, + input_path, + zmq.ROUTER, + bind=True) + self.resources.input_socket = self.input_socket - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket + new_core_engine = lambda index, local_dp_rank=None: CoreEngine( + vllm_config, executor_class, log_stats, input_path, self. + output_path, index, local_dp_rank) - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) + # Start engine core process(es). + self._init_core_engines(vllm_config, new_core_engine, + self.resources.core_engines) - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) + # Wait for engine core process(es) to start. + self._wait_for_engine_startup() - # Wait for engine core process(es) to start. - self._wait_for_engine_startup() - - self.utility_results: dict[int, AnyFuture] = {} + self.utility_results: dict[int, AnyFuture] = {} + success = True + finally: + if not success: + self._finalizer() def _wait_for_engine_startup(self): # Get a sync handle to the socket which can be sync or async. @@ -451,8 +446,18 @@ class MPClient(EngineCoreClient): self.core_engine = core_engine def shutdown(self): + # Terminate background resources. self._finalizer() + def _format_exception(self, e: Exception) -> Exception: + """If errored, use EngineDeadError so root cause is clear.""" + return EngineDeadError( + suppress_context=True) if self.resources.engine_dead else e + + def ensure_alive(self): + if self.resources.engine_dead: + raise EngineDeadError() + def _process_utility_output(output: UtilityOutput, utility_results: dict[int, AnyFuture]): @@ -476,7 +481,7 @@ class SyncMPClient(MPClient): log_stats=log_stats, ) - self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]() # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. 
@@ -487,7 +492,8 @@ class SyncMPClient(MPClient): outputs_queue = self.outputs_queue shutdown_path = get_open_zmq_inproc_path() - self.resources.shutdown_path = shutdown_path + resources = self.resources + resources.shutdown_path = shutdown_path def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) @@ -506,12 +512,15 @@ class SyncMPClient(MPClient): break frames = out_socket.recv_multipart(copy=False) + resources.validate_alive(frames) outputs = decoder.decode(frames) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) finally: # Close sockets. shutdown_socket.close(linger=0) @@ -524,9 +533,16 @@ class SyncMPClient(MPClient): self.output_queue_thread.start() def get_output(self) -> EngineCoreOutputs: - return self.outputs_queue.get() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. + outputs = self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any): + self.ensure_alive() # (Identity, RequestType, SerializedRequest) msg = (self.core_engine.identity, request_type.value, *self.encoder.encode(request)) @@ -608,61 +624,81 @@ class AsyncMPClient(MPClient): log_stats=log_stats, ) - self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None - self.queue_task: Optional[asyncio.Task] = None - - self.outputs_handler: Optional[Callable[ - [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, + Exception]]() + try: + # If we are running in an asyncio event loop, start the queue task. + # Otherwise, it will be started lazily. If it is not started here, + # we could miss EXECUTOR_FAILED messages from engine core if they + # occur prior to any requests being sent. + asyncio.get_running_loop() + self._ensure_output_queue_task() + except RuntimeError: + pass def _ensure_output_queue_task(self): - if self.outputs_queue is not None: + resources = self.resources + if resources.output_queue_task is not None: return # Perform IO in separate task to parallelize as much as possible. # Avoid task having direct reference back to the client. 
- self.outputs_queue = asyncio.Queue() decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler = self.outputs_handler + output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], + Awaitable[None]]] = getattr( + self.__class__, + "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_path = self.output_path output_socket = make_zmq_socket(self.ctx, output_path, zmq.constants.PULL) - self.resources.output_socket = output_socket + resources.output_socket = output_socket async def process_outputs_socket(): - while True: - frames = await output_socket.recv_multipart(copy=False) - outputs: EngineCoreOutputs = decoder.decode(frames) - if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) - continue + try: + while True: + frames = await output_socket.recv_multipart(copy=False) + resources.validate_alive(frames) + outputs: EngineCoreOutputs = decoder.decode(frames) + if outputs.utility_output: + _process_utility_output(outputs.utility_output, + utility_results) + continue - if output_handler is not None: - assert _self_ref is not None - _self = _self_ref() - if not _self: - # Client has been garbage collected, abort. - return - await output_handler(_self, outputs) + if output_handler is not None: + assert _self_ref is not None + _self = _self_ref() + if not _self: + # Client has been garbage collected, abort. + return + await output_handler(_self, outputs) - if outputs.outputs or outputs.scheduler_stats: - outputs_queue.put_nowait(outputs) + if outputs.outputs or outputs.scheduler_stats: + outputs_queue.put_nowait(outputs) + except Exception as e: + outputs_queue.put_nowait(e) - self.queue_task = asyncio.create_task(process_outputs_socket(), - name="EngineCoreOutputQueueTask") + resources.output_queue_task = asyncio.create_task( + process_outputs_socket(), name="EngineCoreOutputQueueTask") async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() + # If an exception arises in process_outputs_socket task, + # it is forwarded to the outputs_queue so we can raise it + # from this (run_output_handler) task to shut down the server. assert self.outputs_queue is not None - return await self.outputs_queue.get() + outputs = await self.outputs_queue.get() + if isinstance(outputs, Exception): + raise self._format_exception(outputs) from None + return outputs def _send_input(self, request_type: EngineCoreRequestType, request: Any, engine: Optional[CoreEngine] = None) -> Awaitable[None]: + self.ensure_alive() if engine is None: engine = self.core_engine @@ -671,6 +707,7 @@ class AsyncMPClient(MPClient): def _send_input_message(self, message: tuple[bytestr, ...], engine: CoreEngine) -> Awaitable[None]: + self.ensure_alive() message = (engine.identity, ) + message return self.input_socket.send_multipart(message, copy=False) @@ -754,18 +791,17 @@ class DPAsyncMPClient(AsyncMPClient): def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): - super().__init__(vllm_config, executor_class, log_stats) - assert len(self.core_engines) > 1 + self.num_engines_running = 0 + self.reqs_in_flight: dict[str, CoreEngine] = {} + + super().__init__(vllm_config, executor_class, log_stats) # Control message used for triggering dp idle mode loop. 
self.start_dp_msg = (EngineCoreRequestType.START_DP.value, *self.encoder.encode(None)) - self.num_engines_running = 0 - self.reqs_in_flight: dict[str, CoreEngine] = {} - - self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + assert len(self.core_engines) > 1 def _init_core_engines( self, diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py new file mode 100644 index 00000000..97dd31d5 --- /dev/null +++ b/vllm/v1/engine/exceptions.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +class EngineGenerateError(Exception): + """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass + + +class EngineDeadError(Exception): + """Raised when the EngineCore dies. Unrecoverable.""" + + def __init__(self, *args, suppress_context: bool = False, **kwargs): + ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause." # noqa: E501 + + super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs) + # Make stack trace clearer when using with LLMEngine by + # silencing irrelevant ZMQError. + self.__suppress_context__ = suppress_context diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 70f072d3..21e2a1ae 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -28,32 +28,40 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[RequestOutput] = None + self.output: Optional[Union[RequestOutput, Exception]] = None self.ready = asyncio.Event() - def put(self, output: RequestOutput) -> None: - if self.output is None: + def put(self, output: Union[RequestOutput, Exception]) -> None: + """Non-blocking put operation.""" + if self.output is None or isinstance(output, Exception): self.output = output self.ready.set() - elif self.aggregate: - # Coalesce the outputs in delta case. - self.output.add(output) - else: - # Just replace latest in non-delta case. - self.output = output + elif isinstance(self.output, RequestOutput): + if self.aggregate: + # Coalesce the outputs in delta case. + self.output.add(output) + else: + # Just replace latest in non-delta case. 
+ self.output = output async def get(self) -> RequestOutput: + """Get operation blocks on put event.""" while (output := self.output) is None: await self.ready.wait() self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output def get_nowait(self) -> Optional[RequestOutput]: + """Non-blocking get operation.""" output = self.output if output is not None: self.output = None self.ready.clear() + if isinstance(output, Exception): + raise output return output @@ -235,6 +243,13 @@ class OutputProcessor: def has_unfinished_requests(self) -> bool: return len(self.request_states) > 0 + def propagate_error(self, e: Exception): + """Propagate error to all generate() tasks.""" + + for _, state in self.request_states.items(): + assert state.queue is not None + state.queue.put(e) + def abort_requests( self, request_ids: Iterable[str], diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index e3a4cd98..3b9feb0d 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from concurrent.futures import Future -from typing import Union +from typing import Callable, Union import torch import torch.distributed as dist @@ -15,6 +15,8 @@ from vllm.executor.uniproc_executor import ( # noqa from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput +FailureCallback = Callable[[], None] + class Executor(ExecutorBase): """ @@ -62,6 +64,13 @@ class Executor(ExecutorBase): args=(kv_cache_configs, )) self.collective_rpc("compile_or_warm_up_model") + def register_failure_callback(self, callback: FailureCallback): + """ + Register a function to be called if the executor enters a permanent + failed state. 
+ """ + pass + def determine_available_memory(self) -> list[int]: # in bytes output = self.collective_rpc("determine_available_memory") return output diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e854c2a4..cff6181f 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -1,21 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 - +import multiprocessing import os import pickle import signal import sys +import threading import time import traceback import weakref +from concurrent.futures import Future from dataclasses import dataclass from enum import Enum, auto from functools import partial +from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess -from typing import Any, Callable, Optional, Union +from threading import Thread +from typing import Any, Callable, Optional, Union, cast import cloudpickle -import psutil -import zmq from vllm.config import VllmConfig from vllm.distributed import (destroy_distributed_environment, @@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, - get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) -from vllm.v1.executor.abstract import Executor + get_open_port) +from vllm.v1.executor.abstract import Executor, FailureCallback +from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -35,6 +38,8 @@ logger = init_logger(__name__) POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 +EXECUTE_MODEL_TIMEOUT_S = 30 + class MultiprocExecutor(Executor): @@ -42,19 +47,9 @@ class MultiprocExecutor(Executor): # Call self.shutdown at exit to clean up # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) - - # The child processes will send SIGUSR1 when unrecoverable - # errors happen. - def sigusr1_handler(signum, frame): - logger.fatal( - "MulitprocExecutor got fatal signal from worker processes, " - "shutting down. See stack trace above for root cause issue.") - # Propagate error up to parent process. - parent_process = psutil.Process().parent() - parent_process.send_signal(signal.SIGUSR1) - self.shutdown() - - signal.signal(signal.SIGUSR1, sigusr1_handler) + self.is_failed = False + self.shutdown_event = threading.Event() + self.failure_callback: Optional[FailureCallback] = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -78,28 +73,94 @@ class MultiprocExecutor(Executor): scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers - self.workers: list[WorkerProcHandle] = [] - for rank in range(self.world_size): - worker = WorkerProc.make_worker_process(self.vllm_config, rank, - rank, - distributed_init_method, - scheduler_output_handle) - self.workers.append(worker) + unready_workers: list[UnreadyWorkerProcHandle] = [] + success = False + try: + for rank in range(self.world_size): + unready_workers.append( + WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=rank, + rank=rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + )) - # Ensure message queues are ready. 
Will deadlock if re-ordered - # Must be kept consistent with the WorkerProc - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() + # Workers must be created before wait_for_ready to avoid + # deadlock, since worker.init_device() does a device sync. + self.workers = WorkerProc.wait_for_ready(unready_workers) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc. + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + self.start_worker_monitor() + success = True + finally: + if not success: + # Clean up the worker procs if there was a failure. + self._ensure_worker_termination( + [w.proc for w in unready_workers]) + + def start_worker_monitor(self): + workers = self.workers + self_ref = weakref.ref(self) + + # Monitors worker process liveness. If any die unexpectedly, + # logs an error, shuts down the executor and invokes the failure + # callback to inform the engine. + def monitor_workers(): + sentinels = [h.proc.sentinel for h in workers] + died = multiprocessing.connection.wait(sentinels) + _self = self_ref() + if not _self or getattr(_self, 'shutting_down', False): + return + _self.is_failed = True + proc_name = next(h.proc.name for h in workers + if h.proc.sentinel == died[0]) + logger.error( + "Worker proc %s died unexpectedly, " + "shutting down executor.", proc_name) + _self.shutdown() + callback = _self.failure_callback + if callback is not None: + _self.failure_callback = None + callback() + + Thread(target=monitor_workers, + daemon=True, + name="MultiprocWorkerMonitor").start() + + def register_failure_callback(self, callback: FailureCallback): + if self.is_failed: + callback() + else: + self.failure_callback = callback + + def execute_model( + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + (output, ) = self.collective_rpc("execute_model", + args=(scheduler_output, ), + rank0_reply_only=True, + timeout=EXECUTE_MODEL_TIMEOUT_S) + return output def collective_rpc(self, method: Union[str, Callable], - timeout: Optional[float] = None, + timeout: Optional[float] = 180.0, args: tuple = (), - kwargs: Optional[dict] = None) -> list[Any]: + kwargs: Optional[dict] = None, + rank0_reply_only: bool = False) -> list[Any]: start_time = time.monotonic() kwargs = kwargs or {} + if self.is_failed: + raise RuntimeError("Executor failed.") + # NOTE: If the args are heterogeneous, then we pack them into a list, # and unpack them in the method of every worker, because every worker # knows their own rank. 
@@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
             else:
                 send_method = cloudpickle.dumps(
                     method, protocol=pickle.HIGHEST_PROTOCOL)
-            self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
+            self.rpc_broadcast_mq.enqueue(
+                (send_method, args, kwargs, rank0_reply_only))
 
-            responses = [None] * self.world_size
-            for w in self.workers:
+            workers = (self.workers[0], ) if rank0_reply_only else self.workers
+            responses = [None] * len(workers)
+            for w in workers:
                 dequeue_timeout = timeout - (time.monotonic() - start_time
                                              ) if timeout is not None else None
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout)
+                    timeout=dequeue_timeout, cancel=self.shutdown_event)
 
                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
-                        "Worker failed with error %s, please check the"
-                        " stack trace above for the root cause", result)
+                        f"Worker failed with error '{result}', please check the"
+                        " stack trace above for the root cause")
 
                 responses[w.rank] = result
 
             return responses
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
-        except Exception as e:
-            # Re-raise any other exceptions
-            raise e
 
-    def _ensure_worker_termination(self):
+    @staticmethod
+    def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have
         received termination requests. Waits for processing, then sends
         termination and kill signals if needed."""
@@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
             return False
 
         # Send SIGTERM if still running
-        active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
+        active_procs = [proc for proc in worker_procs if proc.is_alive()]
         for p in active_procs:
             p.terminate()
         if not wait_for_termination(active_procs, 4):
@@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
             for p in active_procs:
                 p.kill()
 
-        self._cleanup_sockets()
-
-    def _cleanup_sockets(self):
-        for w in self.workers:
-            # Remove the zmq ipc socket file
-            socket_path = w.ready_path.replace("ipc://", "")
-            if os and os.path.exists(socket_path):
-                os.remove(socket_path)
-
     def shutdown(self):
         """Properly shut down the executor and its workers"""
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
+            self.shutdown_event.set()
             for w in self.workers:
                 w.worker_response_mq = None
-            self._ensure_worker_termination()
+            self._ensure_worker_termination([w.proc for w in self.workers])
 
         self.rpc_broadcast_mq = None
@@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
         return
 
 
+@dataclass
+class UnreadyWorkerProcHandle:
+    """WorkerProcess handle before READY."""
+    proc: BaseProcess
+    rank: int
+    ready_pipe: Connection
+
+
 @dataclass
 class WorkerProcHandle:
     proc: BaseProcess
     rank: int
-    ready_path: str
     worker_response_mq: MessageQueue  # The worker process writes to this MQ
 
+    @classmethod
+    def from_unready_handle(
+            cls, unready_handle: UnreadyWorkerProcHandle,
+            worker_response_mq: MessageQueue) -> "WorkerProcHandle":
+        return cls(
+            proc=unready_handle.proc,
+            rank=unready_handle.rank,
+            worker_response_mq=worker_response_mq,
+        )
+
 
 class WorkerProc:
     """Wrapper that runs one Worker in a separate process."""
@@ -203,7 +273,6 @@ class WorkerProc:
         rank: int,
         distributed_init_method: str,
         input_shm_handle: Handle,
-        ready_path: str,
     ):
         self.rank = rank
         wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
@@ -231,18 +300,8 @@ class WorkerProc:
 
         # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)
-        worker_response_mq_handle = self.worker_response_mq.export_handle()
-
-        # Send Readiness signal to EngineCore process.
-        # Set linger here because we want to ensure the message has
-        # been sent before the context is closed.
-        with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
-                            linger=10000) as ready_socket:
-            payload = pickle.dumps(worker_response_mq_handle,
-                                   protocol=pickle.HIGHEST_PROTOCOL)
-            ready_socket.send_string(WorkerProc.READY_STR)
-            ready_socket.send(payload)
 
+        # Initialize device and loads weights
         self.worker.init_device()
         self.worker.load_model()
 
@@ -253,12 +312,10 @@ class WorkerProc:
         rank: int,
         distributed_init_method: str,
         input_shm_handle,  # Receive SchedulerOutput
-    ) -> WorkerProcHandle:
+    ) -> UnreadyWorkerProcHandle:
         context = get_mp_context()
-
-        # ZMQ path for worker to send ready message and shm_broadcast handle
-        # back to core process.
-        ready_path = get_open_zmq_ipc_path()
+        # (reader, writer)
+        reader, writer = context.Pipe(duplex=False)
 
         process_kwargs = {
             "vllm_config": vllm_config,
@@ -266,24 +323,57 @@ class WorkerProc:
             "rank": rank,
             "distributed_init_method": distributed_init_method,
             "input_shm_handle": input_shm_handle,
-            "ready_path": ready_path,
+            "ready_pipe": (reader, writer),
         }
         # Run EngineCore busy loop in background process.
         proc = context.Process(target=WorkerProc.worker_main,
                                kwargs=process_kwargs,
+                               name=f"VllmWorker-{rank}",
                                daemon=True)
 
-        with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
-            proc.start()
+        proc.start()
+        writer.close()
+        return UnreadyWorkerProcHandle(proc, rank, reader)
 
-            # Wait for startup
-            worker_response_mq_handle = WorkerProc.wait_for_startup(
-                proc, ready_socket)
+    @staticmethod
+    def wait_for_ready(
+        unready_proc_handles: list[UnreadyWorkerProcHandle]
+    ) -> list[WorkerProcHandle]:
 
-            worker_response_mq = MessageQueue.create_from_handle(
-                worker_response_mq_handle, 0)
+        e = Exception("WorkerProc initialization failed due to "
+                      "an exception in a background process. "
+                      "See stack trace for root cause.")
 
-        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
+        pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
+        ready_proc_handles: list[Optional[WorkerProcHandle]] = (
+            [None] * len(unready_proc_handles))
+        while pipes:
+            ready = multiprocessing.connection.wait(pipes.keys())
+            for pipe in ready:
+                assert isinstance(pipe, Connection)
+                try:
+                    # Wait until the WorkerProc is ready.
+                    unready_proc_handle = pipes.pop(pipe)
+                    response: dict[str, Any] = pipe.recv()
+                    if response["status"] != "READY":
+                        raise e
+
+                    # Extract the message queue handle.
+                    worker_response_mq = MessageQueue.create_from_handle(
+                        response["handle"], 0)
+                    ready_proc_handles[unready_proc_handle.rank] = (
+                        WorkerProcHandle.from_unready_handle(
+                            unready_proc_handle, worker_response_mq))
+
+                except EOFError:
+                    e.__suppress_context__ = True
+                    raise e from None
+
+                finally:
+                    # Close connection.
+                    pipe.close()
+
+        return cast(list[WorkerProcHandle], ready_proc_handles)
 
     def shutdown(self):
         self.rpc_broadcast_mq = None
@@ -312,51 +402,51 @@ class WorkerProc:
         signal.signal(signal.SIGINT, signal_handler)
 
         worker = None
+        # tuple[Connection, Connection]
+        reader, ready_writer = kwargs.pop("ready_pipe")
         try:
+            reader.close()
             worker = WorkerProc(*args, **kwargs)
 
+            # Send READY once we know everything is loaded
+            ready_writer.send({
+                "status":
+                WorkerProc.READY_STR,
+                "handle":
+                worker.worker_response_mq.export_handle(),
+            })
+
             # Ensure message queues are ready. Will deadlock if re-ordered.
             # Must be kept consistent with the Executor
             worker.rpc_broadcast_mq.wait_until_ready()
             worker.worker_response_mq.wait_until_ready()
+            ready_writer.close()
+            ready_writer = None
 
             worker.worker_busy_loop()
 
-        except SystemExit:
-            logger.debug("Worker interrupted.")
-
         except Exception:
-            # worker_busy_loop sends exceptions to Executor
-            # for shutdown, but if there is an error in startup or an
-            # error with IPC itself, we need to alert the parent.
-            psutil.Process().parent().send_signal(signal.SIGUSR1)
-            raise
+            # NOTE: if an Exception arises in busy_loop, we send
+            # a FAILURE message over the MQ RPC to notify the Executor,
+            # which triggers system shutdown.
+            # TODO(rob): handle case where the MQ itself breaks.
+
+            if ready_writer is not None:
+                logger.exception("WorkerProc failed to start.")
+            else:
+                logger.exception("WorkerProc failed.")
+
+            # The parent sends a SIGTERM to all worker processes if
+            # any worker dies. Set this value so we don't re-throw
+            # SystemExit() to avoid zmq exceptions in __del__.
+            shutdown_requested = True
 
         finally:
+            if ready_writer is not None:
+                ready_writer.close()
             # Clean up once worker exits busy loop
             if worker is not None:
                 worker.shutdown()
-                worker = None
-
-    @staticmethod
-    def wait_for_startup(
-        proc: BaseProcess,
-        ready_socket: zmq.Socket,
-    ) -> Optional[Handle]:
-        """Wait until the Worker is ready."""
-
-        # Wait for Worker to send READY.
-        while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
-            logger.debug("Waiting for WorkerProc to startup.")
-
-            if not proc.is_alive():
-                raise RuntimeError("WorkerProc failed to start.")
-
-        message = ready_socket.recv_string()
-        assert message == WorkerProc.READY_STR
-        handle_frame = ready_socket.recv(copy=False)
-        handle = pickle.loads(handle_frame.buffer)
-        return handle
 
     class ResponseStatus(Enum):
         SUCCESS = auto()
@@ -365,7 +455,7 @@ class WorkerProc:
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
 
             try:
                 if isinstance(method, str):
@@ -377,12 +467,14 @@ class WorkerProc:
                 # Notes have been introduced in python 3.11
                 if hasattr(e, "add_note"):
                     e.add_note(traceback.format_exc())
-                logger.exception("WorkerProc hit an exception: %s", exc_info=e)
+                logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                self.worker_response_mq.enqueue(
-                    (WorkerProc.ResponseStatus.FAILURE, str(e)))
+                if not rank0_only or self.rank == 0:
+                    self.worker_response_mq.enqueue(
+                        (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue
 
-            self.worker_response_mq.enqueue(
-                (WorkerProc.ResponseStatus.SUCCESS, output))
+            if not rank0_only or self.rank == 0:
+                self.worker_response_mq.enqueue(
+                    (WorkerProc.ResponseStatus.SUCCESS, output))