[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Tyler Michael Smith 2024-12-10 01:28:14 -05:00 committed by GitHub
parent bc192a2b09
commit 28b3a1c7e5
21 changed files with 732 additions and 145 deletions

View File

@ -26,6 +26,14 @@ MODELS = [
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
@ -36,6 +44,7 @@ def test_vllm_gc_ed():
assert weak_llm() is None
@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"])
@ -118,6 +127,11 @@ def test_models_distributed(
if attention_backend:
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend
# Import VLLM_USE_V1 dynamically to handle patching
from vllm.envs import VLLM_USE_V1
if VLLM_USE_V1 and distributed_executor_backend != "mp":
pytest.skip(f"Skip {distributed_executor_backend} for V1")
dtype = "half"
max_tokens = 5
@ -143,6 +157,7 @@ def test_models_distributed(
)
@pytest.mark.skip_v1
def test_model_with_failure(vllm_runner) -> None:
try:
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
os.remove(filename)
@pytest.mark.skip_v1
def test_failure_with_async_out_proc(vllm_runner) -> None:
filename = None
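
The fixture changes in this file and in the conftest diff below boil down to one pattern: a parametrized, env-var-based switch that runs every test under both engines, with `@pytest.mark.skip_v1` opting tests out of the V1 pass. A self-contained, simplified sketch of that pattern (illustrative, not the exact test code):

import os

import pytest


@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    # Run each test twice: once with V1 and once without.
    use_v1 = request.param
    if use_v1 and request.node.get_closest_marker("skip_v1"):
        pytest.skip("Skipping test on vllm V1")
    monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
    yield


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Autouse wrapper; promote to conftest.py to cover a whole package.
    pass


@pytest.mark.skip_v1
def test_v0_only():
    assert os.environ["VLLM_USE_V1"] == "0"


def test_both_engines():
    assert os.environ["VLLM_USE_V1"] in ("0", "1")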

View File

@ -5,7 +5,6 @@ from collections import UserList
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
TypedDict, TypeVar, Union)
from unittest.mock import patch
import numpy as np
import pytest
@ -110,7 +109,7 @@ VIDEO_ASSETS = _VideoAssets()
@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
def run_with_both_engines(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
@ -119,11 +118,11 @@ def run_with_both_engines(request):
if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True):
yield
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
monkeypatch.setenv('VLLM_USE_V1', '0')
yield
@pytest.fixture(autouse=True)

View File

@ -1,10 +1,11 @@
import os
import pickle
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from typing import List, Optional, Tuple
from unittest.mock import patch
import torch
@ -21,6 +22,20 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger = init_logger(__name__)
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
or (sys.version_info[:2] == (3, 10)
and sys.version_info[2] >= 8))
def sched_yield():
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)
class ShmRingBuffer:
@ -114,11 +129,14 @@ class ShmRingBuffer:
# and we should suppress the error
pass
def handle(self):
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name)
def __reduce__(self):
return (
self.__class__,
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name),
self.handle(),
)
def __del__(self):
@ -147,7 +165,7 @@ class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)
buffer: Optional[ShmRingBuffer] = None
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
@ -228,7 +246,7 @@ class MessageQueue:
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
buffer_handle=self.buffer.handle(),
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
)
@ -247,8 +265,8 @@ class MessageQueue:
context = Context()
if rank in handle.local_reader_ranks:
assert handle.buffer is not None
self.buffer = handle.buffer
assert handle.buffer_handle is not None
self.buffer = ShmRingBuffer(*handle.buffer_handle)
self.current_idx = 0
self.local_reader_rank = handle.local_reader_ranks.index(rank)
self._is_local_reader = True
@ -314,7 +332,7 @@ class MessageQueue:
assert recv == b"READY"
@contextmanager
def acquire_write(self):
def acquire_write(self, timeout: Optional[float] = None):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
@ -329,16 +347,20 @@ class MessageQueue:
# we need to wait until it is read by all readers
# Release the processor to other threads
os.sched_yield()
sched_yield()
# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError
continue
# found a block that is either
# (1) not written
@ -365,7 +387,7 @@ class MessageQueue:
break
@contextmanager
def acquire_read(self):
def acquire_read(self, timeout: Optional[float] = None):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
@ -383,16 +405,20 @@ class MessageQueue:
# we need to wait until it is written
# Release the processor to other threads
os.sched_yield()
sched_yield()
# if we wait for a long time, we should warn the user
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# if we time out, raise an exception
if (timeout is not None
and time.monotonic() - start_time > timeout):
raise TimeoutError
continue
# found a block that is not read by this reader
# let caller read from the buffer
@ -406,24 +432,26 @@ class MessageQueue:
1) % self.buffer.max_chunks
break
def enqueue(self, obj):
def enqueue(self, obj, timeout: Optional[float] = None):
""" Write to message queue with optional timeout (in seconds) """
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
with self.acquire_write(timeout) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
def dequeue(self):
def dequeue(self, timeout: Optional[float] = None):
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read() as buf:
with self.acquire_read(timeout) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
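
The acquire_write/acquire_read changes above share one pattern: yield the CPU between polls, downgrade the periodic warning to a debug log, and raise TimeoutError once an optional deadline passes. A standalone sketch of that pattern, reusing this commit's sched_yield helper; poll_until itself is illustrative and not part of the diff:

import time
from typing import Callable, Optional

# sched_yield is added by this commit (falls back to time.sleep(0) on
# Python versions where os.sched_yield does not release the GIL).
from vllm.distributed.device_communicators.shm_broadcast import sched_yield


def poll_until(ready: Callable[[], bool],
               timeout: Optional[float] = None,
               log_interval: float = 60.0) -> None:
    """Spin until ready() returns True, yielding the CPU between checks."""
    start = time.monotonic()
    n_waits = 1
    while not ready():
        sched_yield()
        elapsed = time.monotonic() - start
        if elapsed > log_interval * n_waits:
            # The real code emits logger.debug here.
            print(f"still waiting after {elapsed:.0f}s")
            n_waits += 1
        if timeout is not None and elapsed > timeout:
            raise TimeoutError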

View File

@ -3,25 +3,19 @@ import os
from functools import partial
from typing import Any, List, Optional
import torch
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, make_async,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__)
@ -37,30 +31,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and world_size > 1:
maybe_set_triton_cache_manager()
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
@ -122,13 +94,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (

View File

@ -11,8 +11,15 @@ from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import cuda_is_initialized
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__)
@ -270,3 +277,38 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def get_mp_context():
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def set_multiprocessing_worker_envs(parallel_config):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and parallel_config.world_size > 1:
maybe_set_triton_cache_manager()
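
The new helper is meant to run in the parent process before any worker is spawned, which is how both the V0 executor above and the V1 MultiprocExecutor below use it. A hedged sketch of that ordering; DummyParallelConfig and start_workers are illustrative stand-ins:

from dataclasses import dataclass

from vllm.executor.multiproc_worker_utils import (
    get_mp_context, set_multiprocessing_worker_envs)


@dataclass
class DummyParallelConfig:
    # Only world_size is needed by set_multiprocessing_worker_envs.
    world_size: int = 2


def start_workers(target):
    parallel_config = DummyParallelConfig()
    # Must happen in the parent before creating worker processes:
    # forces spawn if CUDA is initialized, caps OMP threads, and applies
    # the Triton cache manager workaround for world_size > 1.
    set_multiprocessing_worker_envs(parallel_config)
    ctx = get_mp_context()
    procs = [ctx.Process(target=target, daemon=True)
             for _ in range(parallel_config.world_size)]
    for p in procs:
        p.start()
    return procs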

View File

@ -5,6 +5,7 @@ from typing import Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -42,7 +43,9 @@ class LogitsProcessor(nn.Module):
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_tpu()
self.use_gather = not current_platform.is_tpu(
) and not envs.VLLM_USE_V1
def forward(
self,

View File

@ -12,6 +12,7 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
@ -110,17 +111,28 @@ class CudaPlatformBase(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
if envs.VLLM_USE_V1:
raise NotImplementedError
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
raise NotImplementedError
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
# NVML utils
@ -249,4 +261,4 @@ try:
if not isinstance(pynvml, _MockModule):
CudaPlatform.log_warnings()
except ModuleNotFoundError:
CudaPlatform.log_warnings()
CudaPlatform.log_warnings()

View File

@ -10,6 +10,7 @@ import importlib.util
import inspect
import ipaddress
import os
import signal
import socket
import subprocess
import sys
@ -1652,3 +1653,28 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
def kill_process_tree(pid: int):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try:
parent = psutil.Process(pid)
except psutil.NoSuchProcess:
return
# Get all children recursively
children = parent.children(recursive=True)
# Send SIGKILL to all children first
for child in children:
with contextlib.suppress(ProcessLookupError):
os.kill(child.pid, signal.SIGKILL)
# Finally kill the parent
with contextlib.suppress(ProcessLookupError):
os.kill(pid, signal.SIGKILL)
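
kill_process_tree is used later in this commit by MPClient.shutdown as an escalation after terminate(). A hedged usage sketch with a throwaway child process:

import time
from multiprocessing import Process

from vllm.utils import kill_process_tree


def _sleeper():
    # Stand-in for a worker that may itself have spawned children.
    time.sleep(60)


if __name__ == "__main__":
    proc = Process(target=_sleeper, daemon=True)
    proc.start()
    proc.terminate()
    proc.join(5)
    if proc.is_alive():
        # Escalate: SIGKILL the whole subtree, mirroring MPClient.shutdown().
        kill_process_tree(proc.pid)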

View File

@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
@ -383,7 +385,7 @@ class Scheduler:
model_runner_output: "ModelRunnerOutput",
) -> List[EngineCoreOutput]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
sampled_token_ids = model_runner_output.sampled_token_ids
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
engine_core_outputs: List[EngineCoreOutput] = []

View File

@ -20,7 +20,7 @@ from vllm.v1.engine.async_stream import AsyncStream
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
@ -30,7 +30,7 @@ class AsyncLLM(EngineClient):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@ -119,14 +119,24 @@ class AsyncLLM(EngineClient):
def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC."""
self.engine_core.shutdown()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
if handler := getattr(self, "output_handler", None):
handler.cancel()
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
else:
assert (distributed_executor_backend is None)
from vllm.v1.executor.uniproc_executor import UniprocExecutor
executor_class = UniprocExecutor
return executor_class
async def add_request(
self,

View File

@ -1,12 +1,12 @@
import multiprocessing
import pickle
import queue
import signal
import threading
import time
from contextlib import contextmanager
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized
from typing import Any, Iterator, List, Tuple, Type, Union
from typing import List, Tuple, Type, Union
import zmq
import zmq.asyncio
@ -20,9 +20,10 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -38,7 +39,7 @@ class EngineCore:
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
usage_context: UsageContext,
):
assert vllm_config.model_config.task != "embedding"
@ -80,7 +81,7 @@ class EngineCore:
num_gpu_blocks = num_gpu_blocks_override
num_cpu_blocks = 0
self.model_executor.initialize_cache(num_gpu_blocks)
self.model_executor.initialize(num_gpu_blocks)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
@ -112,8 +113,11 @@ class EngineCore:
scheduler_output, output)
return engine_core_outputs
def shutdown(self):
self.model_executor.shutdown()
def profile(self, is_start=True):
self.model_executor.worker.profile(is_start)
self.model_executor.profile(is_start)
class EngineCoreProc(EngineCore):
@ -124,7 +128,7 @@ class EngineCoreProc(EngineCore):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
@ -151,32 +155,9 @@ class EngineCoreProc(EngineCore):
daemon=True).start()
# Send Readiness signal to EngineClient.
with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)
@contextmanager
def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
"""Context manager for use """
ctx = zmq.Context()
try:
socket = ctx.socket(type)
if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
yield socket
except KeyboardInterrupt:
logger.debug("EngineCore had Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
@staticmethod
def wait_for_startup(
proc: BaseProcess,
@ -209,7 +190,7 @@ class EngineCoreProc(EngineCore):
@staticmethod
def make_engine_core_process(
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
@ -244,17 +225,38 @@ class EngineCoreProc(EngineCore):
def run_engine_core(*args, **kwargs):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the engine_core
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core = None
try:
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except KeyboardInterrupt:
except SystemExit:
logger.debug("EngineCore interrupted.")
except BaseException as e:
logger.exception(e)
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
engine_core = None
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
@ -272,6 +274,8 @@ class EngineCoreProc(EngineCore):
logger.debug("EngineCore busy loop waiting.")
if self.should_shutdown:
return
except BaseException:
raise
# 2) Handle any new client requests (Abort or Add).
while not self.input_queue.empty():
@ -321,7 +325,7 @@ class EngineCoreProc(EngineCore):
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
with self.make_socket(input_path, zmq.constants.PULL) as socket:
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
@ -349,7 +353,7 @@ class EngineCoreProc(EngineCore):
# Reuse send buffer.
buffer = bytearray()
with self.make_socket(output_path, zmq.constants.PUSH) as socket:
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)

View File

@ -1,5 +1,4 @@
import multiprocessing
import time
from typing import List, Union
import msgspec
@ -7,7 +6,7 @@ import zmq
import zmq.asyncio
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
@ -99,6 +98,12 @@ class InprocClient(EngineCoreClient):
def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)
def shutdown(self):
self.engine_core.shutdown()
def __del__(self):
self.shutdown()
async def profile(self, is_start=True) -> None:
self.engine_core.profile(is_start)
@ -163,10 +168,10 @@ class MPClient(EngineCoreClient):
# Shutdown the process if needed.
if hasattr(self, "proc") and self.proc.is_alive():
self.proc.terminate()
self.proc.join(5)
time.sleep(5)
if self.proc.is_alive():
self.proc.kill()
kill_process_tree(self.proc.pid)
def __del__(self):
self.shutdown()
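
Both run_engine_core above and WorkerProc.worker_main below rely on the same graceful-termination idiom: a signal handler that raises SystemExit exactly once, so SIGTERM/SIGINT unwinds into the finally block rather than interrupting cleanup. A minimal standalone sketch of that idiom (run_loop is illustrative):

import signal
import time


def run_loop():
    shutdown_requested = False

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        # Raise SystemExit only once so cleanup is not re-interrupted.
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT terminates the loop gracefully.
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    try:
        while True:
            time.sleep(0.1)  # stand-in for the busy-loop body
    except SystemExit:
        print("loop interrupted")
    finally:
        print("cleaning up")  # shutdown() runs here in the real code


if __name__ == "__main__":
    run_loop()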

View File

@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
@ -33,7 +33,7 @@ class LLMEngine:
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[GPUExecutor],
executor_class: Type[Executor],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@ -104,10 +104,17 @@ class LLMEngine:
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor
distributed_executor_backend = (
vllm_config.parallel_config.distributed_executor_backend)
if distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
executor_class = MultiprocExecutor
else:
assert (distributed_executor_backend is None)
from vllm.v1.executor.uniproc_executor import UniprocExecutor
executor_class = UniprocExecutor
def stop_remote_worker_execution_loop(self) -> None:
raise NotImplementedError("TP not implemented yet.")
return executor_class
def get_num_unfinished_requests(self) -> int:
return self.detokenizer.get_num_unfinished_requests()

View File

@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from vllm.config import VllmConfig
from vllm.v1.outputs import ModelRunnerOutput
class Executor(ABC):
"""Abstract class for executors."""
@abstractmethod
def __init__(self, vllm_config: VllmConfig) -> None:
raise NotImplementedError
@abstractmethod
def initialize(self, num_gpu_blocks: int) -> None:
raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
raise NotImplementedError
@abstractmethod
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
raise NotImplementedError
@abstractmethod
def profile(self, is_start=True):
raise NotImplementedError
@abstractmethod
def shutdown(self):
pass
@abstractmethod
def check_health(self) -> None:
raise NotImplementedError
@abstractmethod
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
raise NotImplementedError
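
Both UniprocExecutor and the new MultiprocExecutor implement this interface. For reference, a minimal hedged stub that satisfies the ABC (NoopExecutor is illustrative, e.g. for unit tests that never run a model):

from typing import Any, Dict, List, Optional, Tuple

from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput


class NoopExecutor(Executor):
    """Illustrative stub satisfying the Executor ABC."""

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config

    def initialize(self, num_gpu_blocks: int) -> None:
        self.num_gpu_blocks = num_gpu_blocks

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return 0, 0

    def execute_model(self, scheduler_output) -> ModelRunnerOutput:
        raise NotImplementedError("stub does not run a model")

    def profile(self, is_start=True):
        pass

    def shutdown(self):
        pass

    def check_health(self) -> None:
        pass

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        return []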

View File

@ -0,0 +1,375 @@
import atexit
import os
import pickle
import signal
import sys
import time
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing.process import BaseProcess
from typing import Any, Dict, List, Optional, Tuple
import zmq
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.executor.multiproc_worker_utils import (
_add_prefix, get_mp_context, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_open_port,
get_open_zmq_ipc_path)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import make_zmq_socket
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
class MultiprocExecutor:
def __init__(self, vllm_config: VllmConfig) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
atexit.register(self.shutdown)
self.vllm_config = vllm_config
self.parallel_config = vllm_config.parallel_config
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
assert self.world_size == tensor_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}). "
f"Pipeline parallelism is not yet implemented in v1")
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
self.workers: List[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
def initialize(self, num_gpu_blocks: int) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
Determine the number of available KV blocks by invoking the
underlying worker.
"""
num_blocks = self.collective_rpc("determine_num_available_blocks")
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
Args:
method: Name of the worker method to execute
timeout: Maximum time in seconds to wait for execution. Raises a
TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
start_time = time.monotonic()
kwargs = kwargs or {}
try:
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
responses = [None] * self.world_size
for w in self.workers:
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
if status != WorkerProc.ResponseStatus.SUCCESS:
if isinstance(result, Exception):
raise result
else:
raise RuntimeError("Worker failed")
responses[w.rank] = result
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
except Exception as e:
# Re-raise any other exceptions
raise e
def execute_model(
self,
scheduler_output,
) -> ModelRunnerOutput:
model_output = self.collective_rpc("execute_model",
args=(scheduler_output, ))[0]
return model_output
def profile(self, is_start=True):
self.collective_rpc("profile", args=(is_start, ))
return
def _ensure_worker_termination(self):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
termination and kill signals if needed."""
def wait_for_termination(procs, timeout):
start_time = time.time()
while time.time() - start_time < timeout:
if all(not proc.is_alive() for proc in procs):
return True
time.sleep(0.1)
return False
# Send SIGTERM if still running
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
self.workers = None
for p in active_procs:
p.terminate()
if wait_for_termination(active_procs, 4):
return
# Send SIGKILL if still running
active_procs = [p for p in active_procs if p.is_alive()]
for p in active_procs:
p.kill()
def shutdown(self):
"""Properly shut down the executor and its workers"""
if (hasattr(self, 'workers') and self.workers is not None):
for w in self.workers:  # TODO: not sure if needed
w.worker_response_mq = None
self._ensure_worker_termination()
self.rpc_broadcast_mq = None
def check_health(self) -> None:
self.collective_rpc("check_health", timeout=10)
return
@dataclass
class WorkerProcHandle:
proc: BaseProcess
rank: int
ready_path: str
worker_response_mq: MessageQueue # The worker process writes to this MQ
class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""
READY_STR = "READY"
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle: Handle,
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
wrapper.init_worker(vllm_config, local_rank, rank,
distributed_init_method)
self.worker = wrapper.worker
pid = os.getpid()
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
self.worker.initialize()
self.worker.load_model()
@staticmethod
def make_worker_process(
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
) -> WorkerProcHandle:
context = get_mp_context()
# ZMQ path for worker to send ready message and shm_broadcast handle
# back to core process.
ready_path = get_open_zmq_ipc_path()
process_kwargs = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_path": ready_path,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
daemon=True)
proc.start()
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_path)
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
def shutdown(self):
self.rpc_broadcast_mq = None
self.worker_response_mq = None
destroy_model_parallel()
destroy_distributed_environment()
@staticmethod
def worker_main(*args, **kwargs):
""" Worker initialization and execution loops.
This runs in a background process """
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the worker
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
worker = None
try:
worker = WorkerProc(*args, **kwargs)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
worker.worker_busy_loop()
except SystemExit:
logger.debug("Worker interrupted.")
except BaseException as e:
logger.exception(e)
raise
finally:
# Clean up once worker exits busy loop
if worker is not None:
worker.shutdown()
worker = None
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
class ResponseStatus(Enum):
SUCCESS = auto()
FAILURE = auto()
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
try:
output = getattr(self.worker, method)(*args, **kwargs)
except BaseException as e:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
continue
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))

View File

@ -10,7 +10,7 @@ from vllm.v1.worker.gpu_worker import Worker
logger = init_logger(__name__)
class GPUExecutor:
class UniprocExecutor:
def __init__(self, vllm_config: VllmConfig) -> None:
self.vllm_config = vllm_config
@ -54,7 +54,7 @@ class GPUExecutor:
"""
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int) -> None:
def initialize(self, num_gpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
@ -71,7 +71,13 @@ class GPUExecutor:
output = self.worker.execute_model(scheduler_output)
return output
def profile(self, is_start: bool = True):
self.worker.profile(is_start)
def shutdown(self):
self.worker = None
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# UniprocExecutor will always be healthy as long as
# it's running.
return

View File

@ -8,7 +8,7 @@ import torch
class SamplerOutput:
# [num_reqs]
sampled_token_ids: torch.Tensor
sampled_token_ids: List[int]
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor]
@ -20,6 +20,8 @@ class SamplerOutput:
prompt_logprobs: Optional[torch.Tensor]
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use List instead.
@dataclass
class ModelRunnerOutput:
@ -29,7 +31,7 @@ class ModelRunnerOutput:
req_id_to_index: Dict[str, int]
# [num_reqs]
sampled_token_ids_cpu: torch.Tensor
sampled_token_ids: List[int]
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]

View File

@ -37,8 +37,9 @@ class Sampler(nn.Module):
topk_logprobs = None
topk_indices = None
# NOTE: CPU-GPU synchronization happens here.
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
sampled_token_ids=sampled.tolist(),
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,

View File

@ -1,4 +1,11 @@
from typing import Generic, List, TypeVar, overload
from contextlib import contextmanager
from typing import Any, Generic, Iterator, List, TypeVar, overload
import zmq
from vllm.logger import init_logger
logger = init_logger(__name__)
T = TypeVar("T")
@ -62,3 +69,27 @@ class ConstantList(Generic[T]):
def __len__(self):
return len(self._x)
@contextmanager
def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context()
try:
socket = ctx.socket(type)
if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
yield socket
except KeyboardInterrupt:
logger.debug("Worker had Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
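
make_zmq_socket backs the one-shot readiness handshake used throughout this commit: the child binds a PUSH socket and sends READY, while the parent polls a PULL socket until the message arrives or the peer dies. A hedged sketch of that handshake (the ipc path here is illustrative):

import zmq

from vllm.v1.utils import make_zmq_socket

READY_PATH = "ipc:///tmp/example_ready_path"  # illustrative path
READY_STR = "READY"


def child_signal_ready() -> None:
    # PUSH side binds, mirroring EngineCoreProc and WorkerProc.
    with make_zmq_socket(READY_PATH, zmq.constants.PUSH) as sock:
        sock.send_string(READY_STR)


def parent_wait_ready(timeout_ms: int = 5000) -> bool:
    # PULL side connects and polls, mirroring wait_for_startup.
    with make_zmq_socket(READY_PATH, zmq.constants.PULL) as sock:
        if sock.poll(timeout=timeout_ms) == 0:
            return False
        return sock.recv_string() == READY_STR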

View File

@ -34,6 +34,7 @@ class GPUModelRunner:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
input_registry: InputRegistry = INPUT_REGISTRY,
):
self.vllm_config = vllm_config
@ -43,7 +44,6 @@ class GPUModelRunner:
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
@ -52,7 +52,7 @@ class GPUModelRunner:
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = self.device_config.device
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
@ -477,9 +477,7 @@ class GPUModelRunner:
sampling_metadata=sampling_metadata,
)
# NOTE: CPU-GPU synchronization happens here.
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
sampled_token_ids_list = sampled_token_ids.tolist()
sampled_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
@ -490,7 +488,7 @@ class GPUModelRunner:
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
token_id = sampled_token_ids_list[i]
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
else:
@ -512,7 +510,7 @@ class GPUModelRunner:
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids_cpu=sampled_token_ids,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
logprobs_cpu=logprobs,
)

View File

@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@ -56,7 +57,6 @@ class Worker:
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = GPUModelRunner(vllm_config)
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
@ -103,6 +103,9 @@ class Worker:
# Set random seed.
set_random_seed(self.model_config.seed)
# Construct the model runner
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
def load_model(self) -> None:
self.model_runner.load_model()
@ -198,7 +201,7 @@ class Worker:
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
output = self.model_runner.execute_model(scheduler_output)
# TODO(woosuk): Send the output to the engine process.
return output if self.rank == 0 else None
return output
def profile(self, is_start=True):
@ -209,6 +212,10 @@ class Worker:
else:
self.profiler.stop()
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def init_worker_distributed_environment(
parallel_config: ParallelConfig,