[V1] AsyncLLM data parallel (#13923)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-03-27 16:14:41 -07:00 committed by GitHub
parent 112b3e5b3b
commit 15dac210f0
18 changed files with 722 additions and 156 deletions

View File

@ -135,12 +135,14 @@ steps:
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
commands:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
@ -514,7 +516,10 @@ steps:
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py

View File

@ -28,6 +28,7 @@ Multi-node:
--master-port=13345
"""
import os
from time import sleep
from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
@ -36,14 +37,13 @@ from vllm.utils import get_open_port
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# set devices for each dp_rank
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(i)
for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
GPUs_per_dp_rank))
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
# engine processes.
# Sample prompts.
prompts = [
@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
# Give engines time to pause their processing loops before exiting.
sleep(1)
if __name__ == "__main__":
import argparse
@ -152,8 +155,13 @@ if __name__ == "__main__":
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join()
if proc.exitcode:
proc.join(timeout=300)
if proc.exitcode is None:
print(f"Killing process {proc.pid} that "
f"didn't stop within 5 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:
exit_code = proc.exitcode
exit(exit_code)
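
A minimal, self-contained sketch of the join-with-timeout pattern adopted above (the worker body and process count are illustrative, not from the commit):

# Wait up to a deadline per process, force-kill stragglers, and
# propagate any nonzero exit code.
import multiprocessing
import time


def worker():
    time.sleep(1)


if __name__ == "__main__":
    procs = []
    for _ in range(4):
        proc = multiprocessing.Process(target=worker)
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)  # don't block forever on a hung engine
        if proc.exitcode is None:
            # Still alive after the timeout: kill and record failure.
            print(f"Killing process {proc.pid} that didn't stop in time.")
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode
    exit(exit_code)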

View File

@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
core_client: SyncMPClient = client
result = core_client._call_utility("echo", "testarg")
result = core_client.call_utility("echo", "testarg")
assert result == "testarg"
with pytest.raises(Exception) as e_info:
core_client._call_utility("echo", None, "help!")
core_client.call_utility("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!"
@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
core_client: AsyncMPClient = client
result = await core_client._call_utility_async("echo", "testarg")
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"
with pytest.raises(Exception) as e_info:
await core_client._call_utility_async("echo", None, "help!")
await core_client.call_utility_async("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!"

View File

@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from contextlib import ExitStack
from typing import Optional
import pytest
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
enforce_eager=True,
disable_log_requests=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
)
if not current_platform.supports_v1(engine_args.create_model_config()):
pytest.skip(reason="Requires V1-supporting platform.",
allow_module_level=True)
async def generate(engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
count = 0
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True,
output_kind=output_kind,
temperature=0,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA:
count += num_tokens
else:
count = num_tokens
await asyncio.sleep(0.)
return count, request_id
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):
with ExitStack() as after:
prompt = "This is a test of data parallel"
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
NUM_REQUESTS = 100
NUM_EXPECTED_TOKENS = 10
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
for task in pending:
task.cancel()
for task in done:
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
assert not engine.output_processor.has_unfinished_requests()
# testing internals here which may break
core_client: DPAsyncMPClient = engine.engine_core
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for _ in range(10):
if core_client.num_engines_running == 0:
break
await asyncio.sleep(0.5)
assert core_client.num_engines_running == 0
assert not core_client.reqs_in_flight

View File

@ -40,7 +40,8 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, random_uuid, resolve_obj_by_qualname)
get_cpu_memory, get_open_port, random_uuid,
resolve_obj_by_qualname)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -1389,6 +1390,8 @@ class ParallelConfig:
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# Local rank of the data parallel group, defaults to global rank.
data_parallel_rank_local: Optional[int] = None
# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
@ -1493,10 +1496,18 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_size > 1:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
# TODO multi-node
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size
if self.distributed_executor_backend == "external_launcher":

View File

@ -15,6 +15,8 @@ import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
_shutdown_backend,
_unregister_process_group,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group(
pg._register_backend(device, backend_type, backend_class)
return pg
def stateless_destroy_torch_distributed_process_group(
pg: ProcessGroup) -> None:
"""
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
"""
_shutdown_backend(pg)
_unregister_process_group(pg.group_name)
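
A hedged usage sketch pairing the new destroy helper with the existing init helper; only the destroy call is taken from this diff, and the init function's argument list below is an assumption inferred from its name:

from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.distributed.utils import (
    stateless_init_torch_distributed_process_group)

# Assumed init signature: host, port, rank, world_size, backend.
pg = stateless_init_torch_distributed_process_group(
    "127.0.0.1", 29500, rank=0, world_size=1, backend="gloo")
try:
    pass  # run collectives on pg, independent of the default group
finally:
    stateless_destroy_torch_distributed_process_group(pg)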

View File

@ -114,6 +114,7 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
data_parallel_size: int = 1
enable_expert_parallel: bool = False
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
@ -442,6 +443,14 @@ class EngineArgs:
type=int,
default=EngineArgs.tensor_parallel_size,
help='Number of tensor parallel replicas.')
parser.add_argument('--data-parallel-size',
'-dp',
type=int,
default=EngineArgs.data_parallel_size,
help='Number of data parallel replicas. '
'MoE layers will be sharded according to the '
'product of the tensor-parallel-size and '
'data-parallel-size.')
parser.add_argument(
'--enable-expert-parallel',
action='store_true',
@ -1359,6 +1368,7 @@ class EngineArgs:
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
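
The new --data-parallel-size / -dp flag maps onto this EngineArgs field; a short sketch of the programmatic equivalent, mirroring the test added in this commit:

from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",  # MoE model, as in the DP test
    tensor_parallel_size=2,
    data_parallel_size=2,
)
# Per the help text above, MoE layers are sharded across
# tensor_parallel_size * data_parallel_size = 4 ranks.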

View File

@ -2,6 +2,7 @@
import hashlib
import os
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Optional
@ -95,6 +96,7 @@ if TYPE_CHECKING:
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0
VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
@ -625,6 +627,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
# Rank of the process in the data parallel setting.
# Defaults to VLLM_DP_RANK when not set.
"VLLM_DP_RANK_LOCAL":
lambda: int(
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
# World size of the data parallel setting
"VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
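
A small sketch of the fallback semantics added above: VLLM_DP_RANK_LOCAL defaults to VLLM_DP_RANK when unset (the single-node case).

import os

os.environ["VLLM_DP_RANK"] = "3"
os.environ.pop("VLLM_DP_RANK_LOCAL", None)

dp_rank = int(os.getenv("VLLM_DP_RANK", "0"))
dp_rank_local = int(os.getenv("VLLM_DP_RANK_LOCAL", str(dp_rank)))
assert dp_rank_local == 3  # falls back to the global rank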

View File

@ -578,7 +578,7 @@ def get_open_port() -> int:
dp_port = envs.VLLM_DP_MASTER_PORT
while True:
port = _get_open_port()
if port >= dp_port and port < dp_port + 10:
if dp_port <= port < dp_port + 10:
continue
return port
return _get_open_port()
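
The chained comparison above is equivalent to the old condition, just more idiomatic; a sketch of the reservation it expresses:

# Ports in [dp_port, dp_port + 10) are reserved for data parallel
# communication and are skipped when handing out open ports.
def reserved_for_dp(port: int, dp_port: int) -> bool:
    return dp_port <= port < dp_port + 10

assert reserved_for_dp(29505, dp_port=29500)
assert not reserved_for_dp(29510, dp_port=29500)
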
@ -2176,11 +2176,11 @@ def make_zmq_socket(
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")
@ -2188,7 +2188,11 @@ def make_zmq_socket(
@contextlib.contextmanager
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
def zmq_socket_ctx(
path: str,
socket_type: Any,
linger: int = 0,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger.debug("Got Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
ctx.destroy(linger=linger)
def is_in_ray_actor():
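
A hedged sketch of why the bind/connect swap above matters for data parallel: a single PULL consumer binds the output path while one PUSH producer per engine connects to it (fan-in). The IPC path here is illustrative.

import zmq

path = "ipc:///tmp/example_outputs"
ctx = zmq.Context()

pull = ctx.socket(zmq.PULL)
pull.bind(path)  # one consumer binds

for i in range(2):  # e.g. one PUSH per DP engine process
    push = ctx.socket(zmq.PUSH)
    push.connect(path)  # many producers connect
    push.send_string(f"output from engine {i}")

print(pull.recv_string())
print(pull.recv_string())
ctx.destroy(linger=0)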

View File

@ -37,9 +37,10 @@ class Scheduler(SchedulerInterface):
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
@ -48,6 +49,12 @@ class Scheduler(SchedulerInterface):
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = \
@ -663,10 +670,16 @@ class Scheduler(SchedulerInterface):
new_running.append(request)
self.running = new_running
return EngineCoreOutputs(
engine_core_outputs = EngineCoreOutputs(
outputs=outputs,
scheduler_stats=self.make_stats(),
)
if self.include_finished_set:
#TODO currently sending duplicates here, improve this
engine_core_outputs.finished_requests = (
scheduler_output.finished_req_ids | self.finished_req_ids)
return engine_core_outputs
def add_request(self, request: Request) -> None:
self.waiting.append(request)

View File

@ -128,12 +128,18 @@ class EngineCoreOutputs(
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout
engine_index: int = 0
# [num_reqs]
outputs: list[EngineCoreOutput] = []
scheduler_stats: Optional[SchedulerStats] = None
timestamp: float = 0.0
utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused.
engine_paused: bool = False
def __post_init__(self):
if self.timestamp == 0.0:
@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD = b'\x00'
ABORT = b'\x01'
UTILITY = b'\x02'
START_DP = b'\x02'
UTILITY = b'\x03'
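
A sketch of the wire framing implied above: each client request is a (request-type byte, msgpack payload) pair, which is why inserting START_DP required renumbering UTILITY from b'\x02' to b'\x03'.

import msgspec

encoder = msgspec.msgpack.Encoder()
START_DP = b'\x02'
start_dp_msg = (START_DP, encoder.encode(None))  # control frame, no body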

View File

@ -66,11 +66,17 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
self.stat_loggers: list[StatLoggerBase] = []
# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = []
if self.log_stats:
if logger.isEnabledFor(logging.INFO):
self.stat_loggers.append(LoggingStatLogger())
self.stat_loggers.append(PrometheusStatLogger(vllm_config))
for i in range(vllm_config.parallel_config.data_parallel_size):
loggers: list[StatLoggerBase] = []
if logger.isEnabledFor(logging.INFO):
loggers.append(LoggingStatLogger(engine_index=i))
loggers.append(
PrometheusStatLogger(vllm_config, engine_index=i))
self.stat_loggers.append(loggers)
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@ -329,6 +335,7 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
self._record_stats(
engine_index=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
@ -350,12 +357,13 @@ class AsyncLLM(EngineClient):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_index: int = 0,
):
if not self.log_stats:
return
assert scheduler_stats is not None
for stat_logger in self.stat_loggers:
for stat_logger in self.stat_loggers[engine_index]:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
@ -393,8 +401,9 @@ class AsyncLLM(EngineClient):
scheduler_outputs=None,
model_output=None,
) -> None:
for stat_logger in self.stat_loggers:
stat_logger.log()
for loggers in self.stat_loggers:
for stat_logger in loggers:
stat_logger.log()
async def check_health(self) -> None:
logger.debug("Called check_health.")

View File

@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
import os
import queue
import signal
import sys
import threading
import time
from concurrent.futures import Future
from inspect import isclass, signature
from multiprocessing.connection import Connection
from logging import DEBUG
from typing import Any, Optional
import msgspec
@ -14,7 +15,9 @@ import psutil
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
@ -91,6 +94,8 @@ class EngineCore:
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore):
self,
input_path: str,
output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats)
@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore):
args=(input_path, ),
daemon=True).start()
threading.Thread(target=self.process_output_socket,
args=(output_path, ),
args=(output_path, engine_index),
daemon=True).start()
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
self.global_unfinished_reqs = False
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
@staticmethod
def run_engine_core(*args, **kwargs):
def run_engine_core(*args,
dp_rank: int = 0,
local_dp_rank: int = 0,
ready_pipe,
**kwargs):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent()
engine_core = None
engine_core: Optional[EngineCoreProc] = None
try:
engine_core = EngineCoreProc(*args, **kwargs)
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
engine_core = EngineCoreProc(*args, **kwargs)
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})
engine_core.run_busy_loop()
except SystemExit:
@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
step_fn = (self.step
if self.batch_queue is None else self.step_with_batch_queue)
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
while not self.scheduler.has_requests():
logger.debug("EngineCore busy loop waiting.")
req = self.input_queue.get()
self._handle_client_request(*req)
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
# 2) Handle any new client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
# 3) Step the engine core.
outputs = step_fn()
waited = False
while not self.global_unfinished_reqs and not (
self.scheduler.has_requests()):
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
# 4) Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
if waited:
logger.debug(
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self):
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs = self.step_fn()
# Put EngineCoreOutputs into the output queue.
if outputs is not None:
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore):
self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request
output = UtilityOutput(call_id)
@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str):
def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread."""
# Msgpack serialization encoding.
@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)
socket.send(buffer, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
assert 0 <= local_dp_rank <= dp_rank < dp_size
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.platforms.cuda import device_id_to_physical_device_id
tp_size = vllm_config.parallel_config.tensor_parallel_size
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)
def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if local_unfinished_reqs:
# 2) Step the engine core.
self._process_engine_step()
# Check if we have now finished all requests.
local_unfinished_reqs = (
self.scheduler.has_unfinished_requests())
else:
if self.scheduler.has_finished_requests():
# There are no unfinished requests, but there are some
# finished requests remaining to be removed from the
# batch state. This engine step won't perform a forward
# pass but will flush the finished requests to ensure
# up-to-date state is returned in the engine outputs.
self._process_engine_step()
if not self.global_unfinished_reqs:
# All engines are idle.
continue
# There must be unfinished requests in DP peers, run a
# dummy forward pass.
self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
local_unfinished_reqs)
if not self.global_unfinished_reqs:
# Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 16 steps.
self.counter += 1
if self.counter != 16:
return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
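
A hedged sketch of what ParallelConfig.has_unfinished_dp is assumed to do (the helper's body is not shown in this diff): an all-reduce with MAX acts as a logical OR across DP ranks, so the group keeps running while any rank has unfinished requests.

import torch
import torch.distributed as dist


def has_unfinished_dp(dp_group: dist.ProcessGroup,
                      local_unfinished: bool) -> bool:
    tensor = torch.tensor([local_unfinished], dtype=torch.int32)
    # MAX over {0, 1} across ranks == logical OR.
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=dp_group)
    return bool(tensor.item())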

View File

@ -8,10 +8,11 @@ import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import zmq
import zmq.asyncio
@ -60,6 +61,9 @@ class EngineCoreClient(ABC):
"is not currently supported.")
if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode:
@ -207,28 +211,74 @@ class InprocClient(EngineCoreClient):
return self.engine_core.pin_lora(lora_id)
class CoreEngine:
"""One per data parallel rank."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
ctx: Union[zmq.Context, zmq.asyncio.Context],
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
# Paths and sockets for IPC.
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(ctx, input_path,
zmq.constants.PUSH)
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name=f"EngineCore_{index}",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"dp_rank": index,
"local_dp_rank": local_dp_rank,
"executor_class": executor_class,
"log_stats": log_stats,
})
self.num_reqs_in_flight = 0
finally:
if not hasattr(self, "num_reqs_in_flight"):
# Ensure socket is closed if process fails to start.
self.close()
def send_multipart(self, msg_parts: Sequence):
return self.input_socket.send_multipart(msg_parts, copy=False)
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
if socket := getattr(self, "input_socket", None):
socket.close(linger=0)
@dataclass
class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
ctx: zmq.Context
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
proc_handle: Optional[BackgroundProcHandle] = None
shutdown_path: Optional[str] = None
def __call__(self):
"""Clean up background resources."""
if self.proc_handle is not None:
self.proc_handle.shutdown()
for core_engine in self.core_engines:
core_engine.close()
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
if self.output_socket is not None:
self.output_socket.close(linger=0)
if self.input_socket is not None:
self.input_socket.close(linger=0)
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
@ -284,7 +334,7 @@ class MPClient(EngineCoreClient):
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
sync_ctx = zmq.Context()
sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed
@ -293,28 +343,38 @@ class MPClient(EngineCoreClient):
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)
# Paths for IPC.
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
# Start EngineCore in background process.
self.resources.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=self.output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
"log_stats": log_stats,
})
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
index, local_dp_rank)
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
# Wait for engine core process(es) to start.
for engine in self.resources.core_engines:
engine.proc_handle.wait_for_startup()
# Create input socket.
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.input_socket = self.resources.input_socket
self.utility_results: dict[int, AnyFuture] = {}
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Default case - single core engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
core_engine = new_core_engine(
dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
core_engines.append(core_engine)
self.core_engine = core_engine
def shutdown(self):
self._finalizer()
@ -370,7 +430,7 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
break
(frame, ) = out_socket.recv_multipart(copy=False)
frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
@ -391,18 +451,15 @@ class SyncMPClient(MPClient):
def get_output(self) -> EngineCoreOutputs:
return self.outputs_queue.get()
def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)
self.core_engine.send_multipart(msg)
def _call_utility(self, method: str, *args) -> Any:
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future()
self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args))
@ -419,34 +476,34 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True) -> None:
self._call_utility("profile", is_start)
self.call_utility("profile", is_start)
def reset_prefix_cache(self) -> None:
self._call_utility("reset_prefix_cache")
self.call_utility("reset_prefix_cache")
def add_lora(self, lora_request: LoRARequest) -> bool:
return self._call_utility("add_lora", lora_request)
return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self._call_utility("remove_lora", lora_id)
return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]:
return self._call_utility("list_loras")
return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool:
return self._call_utility("pin_lora", lora_id)
return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level)
self.call_utility("sleep", level)
def wake_up(self) -> None:
self._call_utility("wake_up")
self.call_utility("wake_up")
def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping")
return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch")
self.call_utility("execute_dummy_batch")
class AsyncMPClient(MPClient):
@ -464,13 +521,21 @@ class AsyncMPClient(MPClient):
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
self.queue_task: Optional[asyncio.Task] = None
async def _start_output_queue_task(self):
self.outputs_handler: Optional[Callable[
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
def _ensure_output_queue_task(self):
if self.outputs_queue is not None:
return
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self.outputs_queue = asyncio.Queue()
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_handler = self.outputs_handler
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
@ -483,34 +548,52 @@ class AsyncMPClient(MPClient):
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
else:
continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
self.queue_task = asyncio.create_task(process_outputs_socket(),
name="EngineCoreOutputQueueTask")
async def get_output_async(self) -> EngineCoreOutputs:
if self.outputs_queue is None:
await self._start_output_queue_task()
assert self.outputs_queue is not None
self._ensure_output_queue_task()
assert self.outputs_queue is not None
return await self.outputs_queue.get()
async def _send_input(self, request_type: EngineCoreRequestType,
request: Any) -> None:
await self.core_engine.send_multipart(
(request_type.value, self.encoder.encode(request)))
msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
self._ensure_output_queue_task()
if self.outputs_queue is None:
await self._start_output_queue_task()
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method,
*args,
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args) -> Any:
async def _call_utility_async(
self,
method: str,
*args,
engine: CoreEngine,
) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
await self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args))
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
await engine.send_multipart(message)
self._ensure_output_queue_task()
return await future
async def add_request_async(self, request: EngineCoreRequest) -> None:
@ -524,31 +607,146 @@ class AsyncMPClient(MPClient):
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def profile_async(self, is_start: bool = True) -> None:
await self._call_utility_async("profile", is_start)
await self.call_utility_async("profile", is_start)
async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache")
await self.call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level)
await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up")
await self.call_utility_async("wake_up")
async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping")
return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch")
await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self._call_utility_async("add_lora", lora_request)
return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("remove_lora", lora_id)
return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]:
return await self._call_utility_async("list_loras")
return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool:
return await self._call_utility_async("pin_lora", lora_id)
return await self.call_utility_async("pin_lora", lora_id)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool):
super().__init__(vllm_config, executor_class, log_stats)
assert len(self.core_engines) > 1
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment]
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Launch a core engine for each data parallel rank.
dp_size = vllm_config.parallel_config.data_parallel_size
for i in range(dp_size):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines.append(new_core_engine(i, i))
self.core_engines = core_engines
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (await asyncio.gather(*[
self._call_utility_async(method, *args, engine=engine)
for engine in self.core_engines
]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
request.prompt = None
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await chosen_engine.send_multipart(msg)
else:
# Send request to chosen engine and dp start loop
# control message to all other engines.
self.num_engines_running += len(self.core_engines)
await asyncio.gather(*[
engine.send_multipart(msg if engine is
chosen_engine else self.start_dp_msg)
for engine in self.core_engines
])
self._ensure_output_queue_task()
def get_core_engine_for_request(self) -> CoreEngine:
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
@staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient",
outputs: EngineCoreOutputs):
if self.reqs_in_flight:
for req_id in outputs.finished_requests or ():
if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1
if outputs.engine_paused:
assert self.num_engines_running >= 1
self.num_engines_running -= 1
if not self.num_engines_running and self.reqs_in_flight:
# If there are requests in flight here, they must have
# been sent after the engines paused. We must make
# sure to start the other engines:
self.num_engines_running = len(self.core_engines)
coros = [
engine.send_multipart(self.start_dp_msg)
for engine in self.core_engines
if not engine.num_reqs_in_flight
]
if coros:
await asyncio.gather(*coros)
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids:
return
if len(request_ids) == 1:
# Fast-path common case.
if engine := self.reqs_in_flight.get(request_ids[0]):
await self._abort_requests(request_ids, engine)
return
by_engine: dict[CoreEngine, list[str]] = {}
for req_id in request_ids:
if engine := self.reqs_in_flight.get(req_id):
by_engine.setdefault(engine, []).append(req_id)
for engine, req_ids in by_engine.items():
await self._abort_requests(req_ids, engine)
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
await engine.send_multipart((EngineCoreRequestType.ABORT.value,
self.encoder.encode(request_ids)))
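
An illustration of the routing in get_core_engine_for_request above: new requests go to the engine with the fewest in-flight requests, with counters adjusted as requests are added and finish. The engine class here is a stand-in, not the real CoreEngine.

from dataclasses import dataclass


@dataclass
class FakeEngine:  # stand-in for CoreEngine, illustration only
    name: str
    num_reqs_in_flight: int = 0


engines = [FakeEngine("engine0", 3), FakeEngine("engine1", 1)]
chosen = min(engines, key=lambda e: e.num_reqs_in_flight)
assert chosen.name == "engine1"
chosen.num_reqs_in_flight += 1  # incremented on add_request_async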

View File

@ -8,6 +8,7 @@ from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
@ -60,11 +61,13 @@ class LLMEngine:
self.cache_config = vllm_config.cache_config
# important: init dp group before init the engine_core
self.parallel_config = vllm_config.parallel_config
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
# In the decoupled engine case this is handled in EngineCoreProc.
parallel_config = vllm_config.parallel_config
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
self.dp_group = parallel_config.stateless_init_dp_group()
else:
self.dp_group = None
self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
@ -148,7 +151,7 @@ class LLMEngine:
def has_unfinished_requests(self) -> bool:
has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled:
if self.dp_group is None:
return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished)
@ -280,3 +283,7 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

View File

@ -235,7 +235,10 @@ class WorkerProc:
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
@ -270,11 +273,13 @@ class WorkerProc:
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
daemon=True)
proc.start()
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_path)
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
proc.start()
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
@ -337,23 +342,22 @@ class WorkerProc:
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
ready_socket: zmq.Socket,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
# Wait for Worker to send READY.
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = ready_socket.recv(copy=False)
handle = pickle.loads(handle_frame.buffer)
return handle
class ResponseStatus(Enum):
SUCCESS = auto()
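
A hedged sketch of the startup-race fix above: the parent binds the PULL ready-socket before starting the child, so the child's READY message can never target a socket that does not exist yet, and the child-side linger keeps the message alive until delivered. Paths and process structure are illustrative.

import multiprocessing

import zmq

path = "ipc:///tmp/example_ready"  # illustrative path


def worker_main():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(path)
    sock.send_string("READY")
    ctx.destroy(linger=10000)  # ensure READY is flushed before exit


if __name__ == "__main__":
    ctx = zmq.Context()
    ready_socket = ctx.socket(zmq.PULL)
    ready_socket.bind(path)  # bind first...
    proc = multiprocessing.Process(target=worker_main)
    proc.start()  # ...then start the child
    assert ready_socket.recv_string() == "READY"
    proc.join()
    ctx.destroy(linger=0)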

View File

@ -31,7 +31,8 @@ class StatLoggerBase(ABC):
class LoggingStatLogger(StatLoggerBase):
def __init__(self):
def __init__(self, engine_index: int = 0):
self.engine_index = engine_index
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset.
@ -78,11 +79,13 @@ class LoggingStatLogger(StatLoggerBase):
# Format and print output.
logger.info(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
@ -94,7 +97,7 @@ class LoggingStatLogger(StatLoggerBase):
class PrometheusStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()
# Use this flag to hide metrics that were deprecated in
@ -102,8 +105,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.show_hidden_metrics = \
vllm_config.observability_config.show_hidden_metrics
labelnames = ["model_name"]
labelvalues = [vllm_config.model_config.served_model_name]
labelnames = ["model_name", "engine"]
labelvalues = [
vllm_config.model_config.served_model_name,
str(engine_index)
]
max_model_len = vllm_config.model_config.max_model_len
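
A hedged sketch of the effect of the new "engine" label: each DP rank exports its own Prometheus series. The metric name and values below are illustrative, not taken from vLLM.

from prometheus_client import Gauge

running = Gauge("vllm_num_requests_running_example",
                "Running requests per engine",
                labelnames=["model_name", "engine"])
running.labels(model_name="PowerMoE-3b", engine="0").set(5)
running.labels(model_name="PowerMoE-3b", engine="1").set(7)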

View File

@ -105,7 +105,7 @@ class BackgroundProcHandle:
process_kwargs: dict[Any, Any],
):
context = get_mp_context()
reader, writer = context.Pipe(duplex=False)
self.reader, writer = context.Pipe(duplex=False)
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
@ -115,14 +115,17 @@ class BackgroundProcHandle:
process_kwargs["output_path"] = output_path
# Run busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
self.proc = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start()
def wait_for_startup(self):
# Wait for startup.
if reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. "
if self.reader.recv()["status"] != "READY":
raise RuntimeError(f"{self.proc.name} initialization failed. "
"See root cause above.")
def shutdown(self):