vllm/vllm/v1/executor/multiproc_executor.py
youkaichao bf53e0c70b
Support torchrun and SPMD-style offline inference (#12071)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2025-01-16 19:58:53 +08:00

415 lines
15 KiB
Python

import os
import pickle
import signal
import sys
import time
import weakref
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing.process import BaseProcess
from typing import Any, Dict, List, Optional, Tuple
import psutil
import zmq
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class MultiprocExecutor(Executor):
def __init__(self, vllm_config: VllmConfig) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
# The child processes will send SIGUSR1 when unrecoverable
# errors happen.
def sigusr1_handler(signum, frame):
logger.fatal(
"MulitprocExecutor got fatal signal from worker processes, "
"shutting down. See stack trace above for root cause issue.")
# Propagate error up to parent process.
parent_process = psutil.Process().parent()
parent_process.send_signal(signal.SIGUSR1)
self.shutdown()
signal.signal(signal.SIGUSR1, sigusr1_handler)
self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}). "
f"Pipeline parallelism is not yet implemented in v1")
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
self.workers: List[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
def initialize(self, num_gpu_blocks: int) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
Determine the number of available KV blocks by invoking the
underlying worker.
"""
num_blocks = self.collective_rpc("determine_num_available_blocks")
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
Args:
method: Name of the worker method to execute
timeout: Maximum time in seconds to wait for execution. Rases a
TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
start_time = time.monotonic()
kwargs = kwargs or {}
try:
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
responses = [None] * self.world_size
for w in self.workers:
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
if status != WorkerProc.ResponseStatus.SUCCESS:
if isinstance(result, Exception):
raise result
else:
raise RuntimeError("Worker failed")
responses[w.rank] = result
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
except Exception as e:
# Re-raise any other exceptions
raise e
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
model_output = self.collective_rpc("execute_model",
args=(scheduler_output, ))[0]
return model_output
def profile(self, is_start: bool = True):
self.collective_rpc("profile", args=(is_start, ))
return
def _ensure_worker_termination(self):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
termination and kill signals if needed."""
def wait_for_termination(procs, timeout):
if not time:
# If we are in late stage shutdown, the interpreter may replace
# `time` with `None`.
return all(not proc.is_alive() for proc in procs)
start_time = time.time()
while time.time() - start_time < timeout:
if all(not proc.is_alive() for proc in procs):
return True
time.sleep(0.1)
return False
# Send SIGTERM if still running
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
for p in active_procs:
p.terminate()
if not wait_for_termination(active_procs, 4):
# Send SIGKILL if still running
active_procs = [p for p in active_procs if p.is_alive()]
for p in active_procs:
p.kill()
self._cleanup_sockets()
def _cleanup_sockets(self):
for w in self.workers:
# Remove the zmq ipc socket file
socket_path = w.ready_path.replace("ipc://", "")
if os and os.path.exists(socket_path):
os.remove(socket_path)
def shutdown(self):
"""Properly shut down the executor and its workers"""
if getattr(self, 'shutting_down', False):
self.shutting_down = True
for w in self.workers:
w.worker_response_mq = None
self._ensure_worker_termination()
self.rpc_broadcast_mq = None
def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10)
return
@dataclass
class WorkerProcHandle:
proc: BaseProcess
rank: int
ready_path: str
worker_response_mq: MessageQueue # The worker process writes to this MQ
class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""
READY_STR = "READY"
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle: Handle,
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: List[Dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
all_kwargs[rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper.worker
pid = os.getpid()
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
self.worker.init_device()
self.worker.load_model()
@staticmethod
def make_worker_process(
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
) -> WorkerProcHandle:
context = get_mp_context()
# ZMQ path for worker to send ready message and shm_broadcast handle
# back to core process.
ready_path = get_open_zmq_ipc_path()
process_kwargs = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_path": ready_path,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
daemon=True)
proc.start()
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_path)
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
def shutdown(self):
self.rpc_broadcast_mq = None
self.worker_response_mq = None
destroy_model_parallel()
destroy_distributed_environment()
@staticmethod
def worker_main(*args, **kwargs):
""" Worker initialization and execution loops.
This runs a background process """
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
worker = None
try:
worker = WorkerProc(*args, **kwargs)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
worker.worker_busy_loop()
except SystemExit:
logger.debug("Worker interrupted.")
except Exception:
# worker_busy_loop sends exceptions exceptons to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
raise
finally:
# Clean up once worker exits busy loop
if worker is not None:
worker.shutdown()
worker = None
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
class ResponseStatus(Enum):
SUCCESS = auto()
FAILURE = auto()
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
try:
output = getattr(self.worker, method)(*args, **kwargs)
except Exception as e:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
continue
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))