[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in: commit 28b3a1c7e5 (parent bc192a2b09)
@@ -26,6 +26,14 @@ MODELS = [
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")

@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass
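For illustration (not part of this diff), a test that opts out of the V1 pass uses the `skip_v1` marker that `run_with_both_engines` checks:

```python
import pytest

@pytest.mark.skip_v1  # collected only for the V0 (VLLM_USE_V1=0) pass
def test_v0_only_behavior():
    # With the autouse `v1` fixture above, every test in this module runs
    # once per engine; this one is skipped on the V1 run.
    assert True
```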
|
||||
|
||||
|
||||
def test_vllm_gc_ed():
|
||||
"""Verify vllm instance is GC'ed when it is deleted"""
|
||||
llm = LLM("facebook/opt-125m")
|
||||
@ -36,6 +44,7 @@ def test_vllm_gc_ed():
|
||||
assert weak_llm() is None
|
||||
|
||||
|
||||
@pytest.mark.skip_v1
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@ -118,6 +127,11 @@ def test_models_distributed(
|
||||
if attention_backend:
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend
|
||||
|
||||
# Import VLLM_USE_V1 dynamically to handle patching
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
if VLLM_USE_V1 and distributed_executor_backend != "mp":
|
||||
pytest.skip(f"Skip {distributed_executor_backend} for V1")
|
||||
|
||||
dtype = "half"
|
||||
max_tokens = 5
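A hedged sketch of how this code path is exercised end to end (model and sizes are illustrative; requires two GPUs):

```python
import os
os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine before constructing the LLM

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",          # illustrative small model
    tensor_parallel_size=2,
    distributed_executor_backend="mp",  # the only distributed backend V1 accepts here
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=5))
```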
|
||||
|
||||
@ -143,6 +157,7 @@ def test_models_distributed(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_v1
|
||||
def test_model_with_failure(vllm_runner) -> None:
|
||||
try:
|
||||
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
|
||||
@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
@pytest.mark.skip_v1
|
||||
def test_failure_with_async_out_proc(vllm_runner) -> None:
|
||||
|
||||
filename = None
|
||||
|
@ -5,7 +5,6 @@ from collections import UserList
|
||||
from enum import Enum
|
||||
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
|
||||
TypedDict, TypeVar, Union)
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -110,7 +109,7 @@ VIDEO_ASSETS = _VideoAssets()
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def run_with_both_engines(request):
|
||||
def run_with_both_engines(request, monkeypatch):
|
||||
# Automatically runs tests twice, once with V1 and once without
|
||||
use_v1 = request.param
|
||||
# Tests decorated with `@skip_v1` are only run without v1
|
||||
@ -119,11 +118,11 @@ def run_with_both_engines(request):
|
||||
if use_v1:
|
||||
if skip_v1:
|
||||
pytest.skip("Skipping test on vllm V1")
|
||||
with patch('vllm.envs.VLLM_USE_V1', True):
|
||||
yield
|
||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||
else:
|
||||
with patch('vllm.envs.VLLM_USE_V1', False):
|
||||
yield
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -1,10 +1,11 @@
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -21,6 +22,20 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
                   or (sys.version_info[:2] == (3, 10)
                       and sys.version_info[2] >= 8))


def sched_yield():
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)
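The timeout-aware polling loops added below to acquire_write and acquire_read all follow the same busy-wait pattern; a condensed, illustrative sketch (not the actual implementation):

```python
import time

def poll_until(condition, timeout=None, warn_interval=60):
    """Illustrative busy-wait: yield the CPU, periodically note the wait,
    and raise TimeoutError once the optional deadline passes."""
    start = time.monotonic()
    n_warning = 1
    while not condition():
        sched_yield()  # os.sched_yield() or time.sleep(0), as defined above
        waited = time.monotonic() - start
        if waited > warn_interval * n_warning:
            n_warning += 1  # the real code emits a debug log here
        if timeout is not None and waited > timeout:
            raise TimeoutError
```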
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
@ -114,11 +129,14 @@ class ShmRingBuffer:
|
||||
# and we should suppress the error
|
||||
pass
|
||||
|
||||
def handle(self):
|
||||
return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
|
||||
self.shared_memory.name)
|
||||
|
||||
def __reduce__(self):
|
||||
return (
|
||||
self.__class__,
|
||||
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
|
||||
self.shared_memory.name),
|
||||
self.handle(),
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
@ -147,7 +165,7 @@ class Handle:
|
||||
connect_ip: str
|
||||
local_reader_ranks: List[int] = field(default_factory=list)
|
||||
|
||||
buffer: Optional[ShmRingBuffer] = None
|
||||
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
||||
local_subscribe_port: Optional[int] = None
|
||||
remote_subscribe_port: Optional[int] = None
|
||||
|
||||
@ -228,7 +246,7 @@ class MessageQueue:
|
||||
self.handle = Handle(
|
||||
connect_ip=connect_ip,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer=self.buffer,
|
||||
buffer_handle=self.buffer.handle(),
|
||||
local_subscribe_port=local_subscribe_port,
|
||||
remote_subscribe_port=remote_subscribe_port,
|
||||
)
|
||||
@ -247,8 +265,8 @@ class MessageQueue:
|
||||
context = Context()
|
||||
|
||||
if rank in handle.local_reader_ranks:
|
||||
assert handle.buffer is not None
|
||||
self.buffer = handle.buffer
|
||||
assert handle.buffer_handle is not None
|
||||
self.buffer = ShmRingBuffer(*handle.buffer_handle)
|
||||
self.current_idx = 0
|
||||
self.local_reader_rank = handle.local_reader_ranks.index(rank)
|
||||
self._is_local_reader = True
|
||||
@ -314,7 +332,7 @@ class MessageQueue:
|
||||
assert recv == b"READY"
|
||||
|
||||
@contextmanager
|
||||
def acquire_write(self):
|
||||
def acquire_write(self, timeout: Optional[float] = None):
|
||||
assert self._is_writer, "Only writers can acquire write"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
@ -329,16 +347,20 @@ class MessageQueue:
|
||||
# we need to wait until it is read by all readers
|
||||
|
||||
# Release the processor to other threads
|
||||
os.sched_yield()
|
||||
sched_yield()
|
||||
|
||||
# if we wait for a long time, we should warn the user
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time >
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.warning(
|
||||
"No available block found in %s second. ",
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
logger.debug("No available block found in %s second. ",
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
n_warning += 1
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is either
|
||||
# (1) not written
|
||||
@ -365,7 +387,7 @@ class MessageQueue:
|
||||
break
|
||||
|
||||
@contextmanager
|
||||
def acquire_read(self):
|
||||
def acquire_read(self, timeout: Optional[float] = None):
|
||||
assert self._is_local_reader, "Only readers can acquire read"
|
||||
start_time = time.monotonic()
|
||||
n_warning = 1
|
||||
@ -383,16 +405,20 @@ class MessageQueue:
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
os.sched_yield()
|
||||
sched_yield()
|
||||
|
||||
# if we wait for a long time, we should warn the user
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time >
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
|
||||
logger.warning(
|
||||
"No available block found in %s second. ",
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
logger.debug("No available block found in %s second. ",
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||
n_warning += 1
|
||||
|
||||
# if we time out, raise an exception
|
||||
if (timeout is not None
|
||||
and time.monotonic() - start_time > timeout):
|
||||
raise TimeoutError
|
||||
|
||||
continue
|
||||
# found a block that is not read by this reader
|
||||
# let caller read from the buffer
|
||||
@ -406,24 +432,26 @@ class MessageQueue:
|
||||
1) % self.buffer.max_chunks
|
||||
break
|
||||
|
||||
def enqueue(self, obj):
|
||||
def enqueue(self, obj, timeout: Optional[float] = None):
|
||||
""" Write to message queue with optional timeout (in seconds) """
|
||||
assert self._is_writer, "Only writers can enqueue"
|
||||
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if self.n_local_reader > 0:
|
||||
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
|
||||
with self.acquire_write() as buf:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 1 # overflow
|
||||
self.local_socket.send(serialized_obj)
|
||||
else:
|
||||
with self.acquire_write() as buf:
|
||||
with self.acquire_write(timeout) as buf:
|
||||
buf[0] = 0 # not overflow
|
||||
buf[1:len(serialized_obj) + 1] = serialized_obj
|
||||
if self.n_remote_reader > 0:
|
||||
self.remote_socket.send(serialized_obj)
|
||||
|
||||
def dequeue(self):
|
||||
def dequeue(self, timeout: Optional[float] = None):
|
||||
""" Read from message queue with optional timeout (in seconds) """
|
||||
if self._is_local_reader:
|
||||
with self.acquire_read() as buf:
|
||||
with self.acquire_read(timeout) as buf:
|
||||
overflow = buf[0] == 1
|
||||
if not overflow:
|
||||
# no need to know the size of serialized object
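A hedged usage sketch of the new timeout plumbing; the constructor arguments mirror how MultiprocExecutor and WorkerProc use MessageQueue later in this diff, and the handle transfer between processes (pickle over ZMQ) is omitted:

```python
# Writer process (e.g. the executor):
mq = MessageQueue(1, 1)                  # one reader, which is local
handle = mq.export_handle()              # picklable; ship it to the reader process

# Reader process, given `handle`:
reader = MessageQueue.create_from_handle(handle, 0)

# Both sides call wait_until_ready(); the ordering must be kept consistent
# across processes or they deadlock.
mq.wait_until_ready()
reader.wait_until_ready()

mq.enqueue({"method": "check_health"}, timeout=10.0)
try:
    msg = reader.dequeue(timeout=10.0)
except TimeoutError:
    pass  # nothing arrived within 10 seconds
```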
|
||||
|
@ -3,25 +3,19 @@ import os
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.gpu_executor import create_worker
|
||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||
ResultHandler, WorkerMonitor)
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
|
||||
set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
|
||||
cuda_is_initialized, get_distributed_init_method,
|
||||
get_open_port, make_async,
|
||||
get_distributed_init_method, get_open_port, make_async,
|
||||
update_environment_variables)
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.triton_utils import maybe_set_triton_cache_manager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -37,30 +31,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
|
||||
# Disable torch async compiling which won't work with daemonic processes
|
||||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
|
||||
# Configure thread parallelism if OMP_NUM_THREADS isn't set
|
||||
#
|
||||
# Helps to avoid CPU contention. The default of spawning a thread per
|
||||
# core combined with multiprocessing for each GPU can have a negative
|
||||
# impact on performance. The contention is amplified when running in a
|
||||
# container where CPU limits can cause throttling.
|
||||
default_omp_num_threads = 1
|
||||
if "OMP_NUM_THREADS" not in os.environ and (
|
||||
current_parallelism :=
|
||||
torch.get_num_threads()) > default_omp_num_threads:
|
||||
logger.warning(
|
||||
"Reducing Torch parallelism from %d threads to %d to avoid "
|
||||
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
|
||||
"external environment to tune this value as needed.",
|
||||
current_parallelism, default_omp_num_threads)
|
||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||
torch.set_num_threads(default_omp_num_threads)
|
||||
|
||||
# workaround for https://github.com/vllm-project/vllm/issues/6103
|
||||
if HAS_TRITON and world_size > 1:
|
||||
maybe_set_triton_cache_manager()
|
||||
# Set multiprocessing envs that are common to V0 and V1
|
||||
set_multiprocessing_worker_envs(self.parallel_config)
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
@ -122,13 +94,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
|
||||
})
|
||||
|
||||
if (cuda_is_initialized()
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("CUDA was previously initialized. We must use "
|
||||
"the `spawn` multiprocessing start method. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
cuda_device_count = cuda_device_count_stateless()
|
||||
# Use confusing message for more common TP-only case.
|
||||
assert tensor_parallel_size <= cuda_device_count, (
|
||||
|
@ -11,8 +11,15 @@ from multiprocessing.process import BaseProcess
|
||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils.importing import HAS_TRITON
|
||||
from vllm.utils import cuda_is_initialized
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.triton_utils import maybe_set_triton_cache_manager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -270,3 +277,38 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
def get_mp_context():
|
||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||
return multiprocessing.get_context(mp_method)
|
||||
|
||||
|
||||
def set_multiprocessing_worker_envs(parallel_config):
|
||||
""" Set up environment variables that should be used when there are workers
|
||||
in a multiprocessing environment. This should be called by the parent
|
||||
process before worker processes are created"""
|
||||
|
||||
if (cuda_is_initialized()
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("CUDA was previously initialized. We must use "
|
||||
"the `spawn` multiprocessing start method. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
# Configure thread parallelism if OMP_NUM_THREADS isn't set
|
||||
#
|
||||
# Helps to avoid CPU contention. The default of spawning a thread per
|
||||
# core combined with multiprocessing for each GPU can have a negative
|
||||
# impact on performance. The contention is amplified when running in a
|
||||
# container where CPU limits can cause throttling.
|
||||
default_omp_num_threads = 1
|
||||
if "OMP_NUM_THREADS" not in os.environ and (
|
||||
current_parallelism :=
|
||||
torch.get_num_threads()) > default_omp_num_threads:
|
||||
logger.warning(
|
||||
"Reducing Torch parallelism from %d threads to %d to avoid "
|
||||
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
|
||||
"external environment to tune this value as needed.",
|
||||
current_parallelism, default_omp_num_threads)
|
||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||
torch.set_num_threads(default_omp_num_threads)
|
||||
|
||||
# workaround for https://github.com/vllm-project/vllm/issues/6103
|
||||
if HAS_TRITON and parallel_config.world_size > 1:
|
||||
maybe_set_triton_cache_manager()
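For orientation, the intended call order in the parent process (worker_main here is an illustrative target, not a name from this diff):

```python
# Parent (executor) process, before any worker processes are created:
set_multiprocessing_worker_envs(parallel_config)  # spawn method, OMP, Triton setup
ctx = get_mp_context()   # honors VLLM_WORKER_MULTIPROC_METHOD
proc = ctx.Process(target=worker_main, daemon=True)  # worker_main: illustrative
proc.start()
```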
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -42,7 +43,9 @@ class LogitsProcessor(nn.Module):
|
||||
# Soft cap the logits. Used in Gemma 2.
|
||||
self.soft_cap = soft_cap
|
||||
# Whether to use gather or all-gather to gather the logits.
|
||||
self.use_gather = not current_platform.is_tpu()
|
||||
|
||||
self.use_gather = not current_platform.is_tpu(
|
||||
) and not envs.VLLM_USE_V1
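As a hedged sketch of what this flag selects (assumes an initialized tensor-parallel group; in V0 the gathered logits land on a single rank, while V1 workers each need the full-vocab logits):

```python
def _gather_logits(logits, use_gather: bool):
    # Illustrative only; mirrors the branch this flag controls.
    if use_gather:
        # V0 path: concatenate logits on the destination rank only.
        return tensor_model_parallel_gather(logits)
    # TPU / V1 path: every rank receives the full logits.
    return tensor_model_parallel_all_gather(logits)
```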
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -12,6 +12,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
@ -110,17 +111,28 @@ class CudaPlatformBase(Platform):
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
|
||||
# NVML utils
|
||||
@ -249,4 +261,4 @@ try:
|
||||
if not isinstance(pynvml, _MockModule):
|
||||
CudaPlatform.log_warnings()
|
||||
except ModuleNotFoundError:
|
||||
CudaPlatform.log_warnings()
|
||||
CudaPlatform.log_warnings()
|
||||
|
@ -10,6 +10,7 @@ import importlib.util
|
||||
import inspect
|
||||
import ipaddress
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
@ -1652,3 +1653,28 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
|
||||
module_name, obj_name = qualname.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, obj_name)
|
||||
|
||||
|
||||
def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Get all children recursively
    children = parent.children(recursive=True)

    # Send SIGKILL to all children first
    for child in children:
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    # Finally kill the parent
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)
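Elsewhere in this diff, MPClient.shutdown falls back to this helper when graceful termination stalls; a minimal usage sketch:

```python
proc.terminate()                  # ask nicely first (SIGTERM)
proc.join(5)
if proc.is_alive():
    kill_process_tree(proc.pid)   # SIGKILL the process and all descendants
```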
|
||||
|
@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.base import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
@ -383,7 +385,7 @@ class Scheduler:
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
) -> List[EngineCoreOutput]:
|
||||
# NOTE(woosuk): This method doesn't consider speculative decoding.
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
new_running: List[Request] = []
|
||||
engine_core_outputs: List[EngineCoreOutput] = []
|
||||
|
@ -20,7 +20,7 @@ from vllm.v1.engine.async_stream import AsyncStream
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.detokenizer import Detokenizer
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -30,7 +30,7 @@ class AsyncLLM(EngineClient):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
@ -119,14 +119,24 @@ class AsyncLLM(EngineClient):
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
self.engine_core.shutdown()
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown()
|
||||
|
||||
if handler := getattr(self, "output_handler", None):
|
||||
handler.cancel()
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, vllm_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
distributed_executor_backend = (
|
||||
vllm_config.parallel_config.distributed_executor_backend)
|
||||
if distributed_executor_backend == "mp":
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
executor_class = MultiprocExecutor
|
||||
else:
|
||||
assert (distributed_executor_backend is None)
|
||||
from vllm.v1.executor.uniproc_executor import UniprocExecutor
|
||||
executor_class = UniprocExecutor
|
||||
return executor_class
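In effect (a condensed restatement of the branch above, with vllm_config assumed to be in scope):

```python
# distributed_executor_backend == "mp" -> MultiprocExecutor: one worker process
#     per tensor-parallel rank, wired up with shared-memory message queues.
# distributed_executor_backend is None -> UniprocExecutor: a single in-process worker.
executor_class = AsyncLLM._get_executor_cls(vllm_config)
```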
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
|
@ -1,12 +1,12 @@
|
||||
import multiprocessing
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.process import BaseProcess
|
||||
from multiprocessing.sharedctypes import Synchronized
|
||||
from typing import Any, Iterator, List, Tuple, Type, Union
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -20,9 +20,10 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapper
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
from vllm.v1.utils import make_zmq_socket
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -38,7 +39,7 @@ class EngineCore:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
):
|
||||
assert vllm_config.model_config.task != "embedding"
|
||||
@ -80,7 +81,7 @@ class EngineCore:
|
||||
num_gpu_blocks = num_gpu_blocks_override
|
||||
|
||||
num_cpu_blocks = 0
|
||||
self.model_executor.initialize_cache(num_gpu_blocks)
|
||||
self.model_executor.initialize(num_gpu_blocks)
|
||||
elapsed = time.time() - start
|
||||
logger.info(("init engine (profile, create kv cache, "
|
||||
"warmup model) took %.2f seconds"), elapsed)
|
||||
@ -112,8 +113,11 @@ class EngineCore:
|
||||
scheduler_output, output)
|
||||
return engine_core_outputs
|
||||
|
||||
def shutdown(self):
|
||||
self.model_executor.shutdown()
|
||||
|
||||
def profile(self, is_start=True):
|
||||
self.model_executor.worker.profile(is_start)
|
||||
self.model_executor.profile(is_start)
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
@ -124,7 +128,7 @@ class EngineCoreProc(EngineCore):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
@ -151,32 +155,9 @@ class EngineCoreProc(EngineCore):
|
||||
daemon=True).start()
|
||||
|
||||
# Send Readiness signal to EngineClient.
|
||||
with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
ready_socket.send_string(EngineCoreProc.READY_STR)
|
||||
|
||||
@contextmanager
|
||||
def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for use """
|
||||
|
||||
ctx = zmq.Context()
|
||||
try:
|
||||
socket = ctx.socket(type)
|
||||
|
||||
if type == zmq.constants.PULL:
|
||||
socket.connect(path)
|
||||
elif type == zmq.constants.PUSH:
|
||||
socket.bind(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {type}")
|
||||
|
||||
yield socket
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("EngineCore had Keyboard Interrupt.")
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_startup(
|
||||
proc: BaseProcess,
|
||||
@ -209,7 +190,7 @@ class EngineCoreProc(EngineCore):
|
||||
@staticmethod
|
||||
def make_engine_core_process(
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
executor_class: Type[Executor],
|
||||
usage_context: UsageContext,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
@ -244,17 +225,38 @@ class EngineCoreProc(EngineCore):
|
||||
def run_engine_core(*args, **kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the engine_core
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core = None
|
||||
try:
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
except SystemExit:
|
||||
logger.debug("EngineCore interrupted.")
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
finally:
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
engine_core = None
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
@ -272,6 +274,8 @@ class EngineCoreProc(EngineCore):
|
||||
logger.debug("EngineCore busy loop waiting.")
|
||||
if self.should_shutdown:
|
||||
return
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
# 2) Handle any new client requests (Abort or Add).
|
||||
while not self.input_queue.empty():
|
||||
@ -321,7 +325,7 @@ class EngineCoreProc(EngineCore):
|
||||
decoder_add_req = PickleEncoder()
|
||||
decoder_abort_req = PickleEncoder()
|
||||
|
||||
with self.make_socket(input_path, zmq.constants.PULL) as socket:
|
||||
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
@ -349,7 +353,7 @@ class EngineCoreProc(EngineCore):
|
||||
# Reuse send buffer.
|
||||
buffer = bytearray()
|
||||
|
||||
with self.make_socket(output_path, zmq.constants.PUSH) as socket:
|
||||
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
|
||||
while True:
|
||||
engine_core_outputs = self.output_queue.get()
|
||||
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import msgspec
|
||||
@ -7,7 +6,7 @@ import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_open_zmq_ipc_path
|
||||
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
@ -99,6 +98,12 @@ class InprocClient(EngineCoreClient):
|
||||
def abort_requests(self, request_ids: List[str]) -> None:
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def shutdown(self):
|
||||
self.engine_core.shutdown()
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
async def profile(self, is_start=True) -> None:
|
||||
self.engine_core.profile(is_start)
|
||||
|
||||
@ -163,10 +168,10 @@ class MPClient(EngineCoreClient):
|
||||
# Shutdown the process if needed.
|
||||
if hasattr(self, "proc") and self.proc.is_alive():
|
||||
self.proc.terminate()
|
||||
self.proc.join(5)
|
||||
|
||||
time.sleep(5)
|
||||
if self.proc.is_alive():
|
||||
self.proc.kill()
|
||||
kill_process_tree(self.proc.pid)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.detokenizer import Detokenizer
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.gpu_executor import GPUExecutor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -33,7 +33,7 @@ class LLMEngine:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: Type[GPUExecutor],
|
||||
executor_class: Type[Executor],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
@ -104,10 +104,17 @@ class LLMEngine:
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(cls, vllm_config: VllmConfig):
|
||||
return GPUExecutor
|
||||
distributed_executor_backend = (
|
||||
vllm_config.parallel_config.distributed_executor_backend)
|
||||
if distributed_executor_backend == "mp":
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
executor_class = MultiprocExecutor
|
||||
else:
|
||||
assert (distributed_executor_backend is None)
|
||||
from vllm.v1.executor.uniproc_executor import UniprocExecutor
|
||||
executor_class = UniprocExecutor
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
raise NotImplementedError("TP not implemented yet.")
|
||||
return executor_class
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.detokenizer.get_num_unfinished_requests()
|
||||
|
vllm/v1/executor/abstract.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple

from vllm.config import VllmConfig
from vllm.v1.outputs import ModelRunnerOutput


class Executor(ABC):
    """Abstract class for executors."""

    @abstractmethod
    def __init__(self, vllm_config: VllmConfig) -> None:
        raise NotImplementedError

    @abstractmethod
    def initialize(self, num_gpu_blocks: int) -> None:
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        scheduler_output,
    ) -> ModelRunnerOutput:
        raise NotImplementedError

    @abstractmethod
    def profile(self, is_start=True):
        raise NotImplementedError

    @abstractmethod
    def shutdown(self):
        pass

    @abstractmethod
    def check_health(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> []:
        raise NotImplementedError
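A minimal conforming subclass, shown only to make the contract concrete (not code from this diff):

```python
class NoopExecutor(Executor):
    """Illustrative stub that satisfies the abstract interface."""

    def __init__(self, vllm_config: VllmConfig) -> None:
        self.vllm_config = vllm_config

    def initialize(self, num_gpu_blocks: int) -> None:
        pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return 0, 0

    def execute_model(self, scheduler_output) -> ModelRunnerOutput:
        raise NotImplementedError("stub")

    def profile(self, is_start=True):
        pass

    def shutdown(self):
        pass

    def check_health(self) -> None:
        pass

    def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
        return []
```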
|
vllm/v1/executor/multiproc_executor.py (new file, 375 lines)
@@ -0,0 +1,375 @@
import atexit
|
||||
import os
|
||||
import pickle
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import zmq
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||
MessageQueue)
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
_add_prefix, get_mp_context, set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_open_port,
|
||||
get_open_zmq_ipc_path)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.utils import make_zmq_socket
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
|
||||
|
||||
class MultiprocExecutor:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig) -> None:
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
assert self.world_size == tensor_parallel_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tensor_parallel_size}). "
|
||||
f"Pipeline parallelism is not yet implemented in v1")
|
||||
|
||||
# Set multiprocessing envs that are common to V0 and V1
|
||||
set_multiprocessing_worker_envs(self.parallel_config)
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
# 127.0.0.1 for communication.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
"127.0.0.1", get_open_port())
|
||||
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
# Create workers
|
||||
self.workers: List[WorkerProcHandle] = []
|
||||
for rank in range(self.world_size):
|
||||
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
|
||||
distributed_init_method,
|
||||
scheduler_output_handle)
|
||||
self.workers.append(worker)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
"""
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""
|
||||
Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
num_blocks = self.collective_rpc("determine_num_available_blocks")
|
||||
|
||||
# Since we use a shared centralized controller, we take the minimum
|
||||
# number of blocks across all workers to make sure all the memory
|
||||
# operators can be applied to all workers.
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
timeout: Optional[float] = None,
|
||||
args: Tuple = (),
|
||||
kwargs: Optional[Dict] = None) -> []:
|
||||
"""
|
||||
Execute an RPC call on workers.
|
||||
|
||||
Args:
|
||||
method: Name of the worker method to execute
|
||||
timeout: Maximum time in seconds to wait for execution. Raises a
TimeoutError on timeout. None means wait indefinitely.
|
||||
args: Positional arguments to pass to the worker method
|
||||
kwargs: Keyword arguments to pass to the worker method
|
||||
|
||||
Returns:
|
||||
List of results from each worker
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
kwargs = kwargs or {}
|
||||
|
||||
try:
|
||||
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
|
||||
|
||||
responses = [None] * self.world_size
|
||||
for w in self.workers:
|
||||
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout)
|
||||
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
else:
|
||||
raise RuntimeError("Worker failed")
|
||||
|
||||
responses[w.rank] = result
|
||||
|
||||
return responses
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
except Exception as e:
|
||||
# Re-raise any other exceptions
|
||||
raise e
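Callers elsewhere in this file use it like so (these exact calls appear below in initialize, determine_num_available_blocks, and check_health):

```python
# Inside MultiprocExecutor methods, `self` is the executor:
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("compile_or_warm_up_model")
num_blocks = self.collective_rpc("determine_num_available_blocks")
self.collective_rpc("check_health", timeout=10)
```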
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output,
|
||||
) -> ModelRunnerOutput:
|
||||
model_output = self.collective_rpc("execute_model",
|
||||
args=(scheduler_output, ))[0]
|
||||
return model_output
|
||||
|
||||
def profile(self, is_start=True):
|
||||
self.collective_rpc("profile", args=(is_start, ))
|
||||
return
|
||||
|
||||
def _ensure_worker_termination(self):
|
||||
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||
received termination requests. Waits for processing, then sends
|
||||
termination and kill signals if needed."""
|
||||
|
||||
def wait_for_termination(procs, timeout):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if all(not proc.is_alive() for proc in procs):
|
||||
return True
|
||||
time.sleep(0.1)
|
||||
return False
|
||||
|
||||
# Send SIGTERM if still running
|
||||
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
|
||||
self.workers = None
|
||||
for p in active_procs:
|
||||
p.terminate()
|
||||
if wait_for_termination(active_procs, 4):
|
||||
return
|
||||
|
||||
# Send SIGKILL if still running
|
||||
active_procs = [p for p in active_procs if p.is_alive()]
|
||||
for p in active_procs:
|
||||
p.kill()
|
||||
|
||||
def shutdown(self):
|
||||
"""Properly shut down the executor and its workers"""
|
||||
if (hasattr(self, 'workers') and self.workers is not None):
|
||||
for w in self.workers: #TODO: not sure if needed
|
||||
w.worker_response_mq = None
|
||||
self._ensure_worker_termination()
|
||||
|
||||
self.rpc_broadcast_mq = None
|
||||
|
||||
def check_health(self) -> None:
|
||||
self.collective_rpc("check_health", timeout=10)
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
ready_path: str
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
|
||||
|
||||
class WorkerProc:
|
||||
"""Wrapper that runs one Worker in a separate process."""
|
||||
|
||||
READY_STR = "READY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle: Handle,
|
||||
ready_path: str,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
|
||||
wrapper.init_worker(vllm_config, local_rank, rank,
|
||||
distributed_init_method)
|
||||
self.worker = wrapper.worker
|
||||
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
|
||||
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
|
||||
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
||||
|
||||
# Send Readiness signal to EngineCore process.
|
||||
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
payload = pickle.dumps(worker_response_mq_handle,
|
||||
protocol=pickle.HIGHEST_PROTOCOL)
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
ready_socket.send(payload)
|
||||
|
||||
self.worker.initialize()
|
||||
self.worker.load_model()
|
||||
|
||||
@staticmethod
|
||||
def make_worker_process(
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
) -> WorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
|
||||
# ZMQ path for worker to send ready message and shm_broadcast handle
|
||||
# back to core process.
|
||||
ready_path = get_open_zmq_ipc_path()
|
||||
|
||||
process_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"input_shm_handle": input_shm_handle,
|
||||
"ready_path": ready_path,
|
||||
}
|
||||
# Run the Worker busy loop in a background process.
|
||||
proc = context.Process(target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for startup
|
||||
worker_response_mq_handle = WorkerProc.wait_for_startup(
|
||||
proc, ready_path)
|
||||
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
worker_response_mq_handle, 0)
|
||||
|
||||
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
|
||||
|
||||
def shutdown(self):
|
||||
self.rpc_broadcast_mq = None
|
||||
self.worker_response_mq = None
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
|
||||
@staticmethod
|
||||
def worker_main(*args, **kwargs):
|
||||
""" Worker initialization and execution loops.
|
||||
This runs a background process """
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the worker
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
worker = None
|
||||
try:
|
||||
worker = WorkerProc(*args, **kwargs)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
worker.worker_response_mq.wait_until_ready()
|
||||
|
||||
worker.worker_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
logger.debug("Worker interrupted.")
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Clean up once worker exits busy loop
|
||||
if worker is not None:
|
||||
worker.shutdown()
|
||||
worker = None
|
||||
|
||||
@staticmethod
|
||||
def wait_for_startup(
|
||||
proc: BaseProcess,
|
||||
ready_path: str,
|
||||
) -> Optional[Handle]:
|
||||
"""Wait until the Worker is ready."""
|
||||
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
|
||||
|
||||
# Wait for Worker to send READY.
|
||||
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
logger.debug("Waiting for WorkerProc to startup.")
|
||||
|
||||
if not proc.is_alive():
|
||||
raise RuntimeError("WorkerProc failed to start.")
|
||||
|
||||
message = socket.recv_string()
|
||||
assert message == WorkerProc.READY_STR
|
||||
handle_frame = socket.recv(copy=False)
|
||||
handle = pickle.loads(handle_frame.buffer)
|
||||
return handle
|
||||
|
||||
class ResponseStatus(Enum):
|
||||
SUCCESS = auto()
|
||||
FAILURE = auto()
|
||||
|
||||
def worker_busy_loop(self):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
try:
|
||||
output = getattr(self.worker, method)(*args, **kwargs)
|
||||
except BaseException as e:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, e))
|
||||
continue
|
||||
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
@ -10,7 +10,7 @@ from vllm.v1.worker.gpu_worker import Worker
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUExecutor:
|
||||
class UniprocExecutor:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
@ -54,7 +54,7 @@ class GPUExecutor:
|
||||
"""
|
||||
return self.worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker.
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
@ -71,7 +71,13 @@ class GPUExecutor:
|
||||
output = self.worker.execute_model(scheduler_output)
|
||||
return output
|
||||
|
||||
def profile(self, is_start: bool = True):
|
||||
self.worker.profile(is_start)
|
||||
|
||||
def shutdown(self):
|
||||
self.worker = None
|
||||
|
||||
def check_health(self) -> None:
|
||||
# GPUExecutor will always be healthy as long as
|
||||
# UniprocExecutor will always be healthy as long as
|
||||
# it's running.
|
||||
return
|
@ -8,7 +8,7 @@ import torch
|
||||
class SamplerOutput:
|
||||
|
||||
# [num_reqs]
|
||||
sampled_token_ids: torch.Tensor
|
||||
sampled_token_ids: List[int]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: Optional[torch.Tensor]
|
||||
@ -20,6 +20,8 @@ class SamplerOutput:
|
||||
prompt_logprobs: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# ModelRunnerOutput is serialized and sent to the scheduler process.
|
||||
# This is expensive for torch.Tensor so prefer to use List instead.
|
||||
@dataclass
|
||||
class ModelRunnerOutput:
|
||||
|
||||
@ -29,7 +31,7 @@ class ModelRunnerOutput:
|
||||
req_id_to_index: Dict[str, int]
|
||||
|
||||
# [num_reqs]
|
||||
sampled_token_ids_cpu: torch.Tensor
|
||||
sampled_token_ids: List[int]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids_cpu: Optional[torch.Tensor]
|
||||
|
@ -37,8 +37,9 @@ class Sampler(nn.Module):
|
||||
topk_logprobs = None
|
||||
topk_indices = None
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=sampled,
|
||||
sampled_token_ids=sampled.tolist(),
|
||||
logprob_token_ids=topk_indices,
|
||||
logprobs=topk_logprobs,
|
||||
prompt_logprob_token_ids=None,
|
||||
|
@ -1,4 +1,11 @@
|
||||
from typing import Generic, List, TypeVar, overload
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generic, Iterator, List, TypeVar, overload
|
||||
|
||||
import zmq
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -62,3 +69,27 @@ class ConstantList(Generic[T]):
|
||||
|
||||
def __len__(self):
|
||||
return len(self._x)
|
||||
|
||||
|
||||
@contextmanager
def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket"""

    ctx = zmq.Context()
    try:
        socket = ctx.socket(type)

        if type == zmq.constants.PULL:
            socket.connect(path)
        elif type == zmq.constants.PUSH:
            socket.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")

        yield socket

    except KeyboardInterrupt:
        logger.debug("Worker had Keyboard Interrupt.")

    finally:
        ctx.destroy(linger=0)
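The PUSH/PULL pairing this helper enforces is used for the readiness handshake between parent and child processes in this diff; a condensed sketch (ready_path and the payload are simplified):

```python
# Child process: bind a PUSH socket and announce readiness.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
    ready_socket.send_string("READY")

# Parent process: connect a PULL socket and wait for the message.
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
    while socket.poll(timeout=5000) == 0:
        pass  # the real code also checks that the child process is still alive
    assert socket.recv_string() == "READY"
```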
|
||||
|
@ -34,6 +34,7 @@ class GPUModelRunner:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
@ -43,7 +44,6 @@ class GPUModelRunner:
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
@ -52,7 +52,7 @@ class GPUModelRunner:
|
||||
cache_config = self.cache_config
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = self.device_config.device
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
if cache_config.cache_dtype == "auto":
|
||||
@ -477,9 +477,7 @@ class GPUModelRunner:
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
|
||||
sampled_token_ids_list = sampled_token_ids.tolist()
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@ -490,7 +488,7 @@ class GPUModelRunner:
|
||||
assert seq_len <= req_state.num_tokens
|
||||
if seq_len == req_state.num_tokens:
|
||||
# Append the sampled token to the output token ids.
|
||||
token_id = sampled_token_ids_list[i]
|
||||
token_id = sampled_token_ids[i]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
req_state.output_token_ids.append(token_id)
|
||||
else:
|
||||
@ -512,7 +510,7 @@ class GPUModelRunner:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids[:num_reqs],
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids_cpu=sampled_token_ids,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprob_token_ids_cpu=logprob_token_ids,
|
||||
logprobs_cpu=logprobs,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
@ -56,7 +57,6 @@ class Worker:
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
self.model_runner = GPUModelRunner(vllm_config)
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
@ -103,6 +103,9 @@ class Worker:
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
# Construct the model runner
|
||||
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model_runner.load_model()
|
||||
|
||||
@ -198,7 +201,7 @@ class Worker:
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
output = self.model_runner.execute_model(scheduler_output)
|
||||
# TODO(woosuk): Send the output to the engine process.
|
||||
return output if self.rank == 0 else None
|
||||
return output
|
||||
|
||||
def profile(self, is_start=True):
|
||||
@ -209,6 +212,10 @@ class Worker:
|
||||
else:
|
||||
self.profiler.stop()
|
||||
|
||||
def check_health(self) -> None:
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,