From 15dac210f0e6b907f191911917238273042552ed Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 27 Mar 2025 16:14:41 -0700 Subject: [PATCH] [V1] AsyncLLM data parallel (#13923) Signed-off-by: Nick Hill --- .buildkite/test-pipeline.yaml | 5 + examples/offline_inference/data_parallel.py | 22 +- tests/v1/engine/test_engine_core_client.py | 8 +- tests/v1/test_async_llm_dp.py | 109 +++++++ vllm/config.py | 21 +- vllm/distributed/utils.py | 12 + vllm/engine/arg_utils.py | 10 + vllm/envs.py | 8 + vllm/utils.py | 14 +- vllm/v1/core/sched/scheduler.py | 17 +- vllm/v1/engine/__init__.py | 9 +- vllm/v1/engine/async_llm.py | 23 +- vllm/v1/engine/core.py | 208 ++++++++++-- vllm/v1/engine/core_client.py | 332 ++++++++++++++++---- vllm/v1/engine/llm_engine.py | 17 +- vllm/v1/executor/multiproc_executor.py | 38 ++- vllm/v1/metrics/loggers.py | 14 +- vllm/v1/utils.py | 11 +- 18 files changed, 722 insertions(+), 156 deletions(-) create mode 100644 tests/v1/test_async_llm_dp.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f22b2b0a..428b4c59 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -135,12 +135,14 @@ steps: - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py + - tests/v1/test_async_llm_dp.py commands: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py @@ -514,7 +516,10 @@ steps: - vllm/worker/worker.py - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py + - tests/v1/test_async_llm_dp.py + - vllm/v1/engine/ commands: + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 232afd8b..04a79e2f 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -28,6 +28,7 @@ Multi-node: --master-port=13345 """ import os +from time import sleep from vllm import LLM, SamplingParams from vllm.utils import get_open_port @@ -36,14 +37,13 @@ from vllm.utils import get_open_port def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port) - # set devices for each dp_rank - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( - str(i) - for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) * - GPUs_per_dp_rank)) + + # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the + # engine processes. # Sample prompts. 
prompts = [ @@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, " f"Generated text: {generated_text!r}") + # Give engines time to pause their processing loops before exiting. + sleep(1) + if __name__ == "__main__": import argparse @@ -152,8 +155,13 @@ if __name__ == "__main__": procs.append(proc) exit_code = 0 for proc in procs: - proc.join() - if proc.exitcode: + proc.join(timeout=300) + if proc.exitcode is None: + print(f"Killing process {proc.pid} that " + f"didn't stop within 5 minutes.") + proc.kill() + exit_code = 1 + elif proc.exitcode: exit_code = proc.exitcode exit(exit_code) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 48f451a5..68844b87 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, core_client: SyncMPClient = client - result = core_client._call_utility("echo", "testarg") + result = core_client.call_utility("echo", "testarg") assert result == "testarg" with pytest.raises(Exception) as e_info: - core_client._call_utility("echo", None, "help!") + core_client.call_utility("echo", None, "help!") assert str(e_info.value) == "Call to echo method failed: help!" @@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): core_client: AsyncMPClient = client - result = await core_client._call_utility_async("echo", "testarg") + result = await core_client.call_utility_async("echo", "testarg") assert result == "testarg" with pytest.raises(Exception) as e_info: - await core_client._call_utility_async("echo", None, "help!") + await core_client.call_utility_async("echo", None, "help!") assert str(e_info.value) == "Call to echo method failed: help!" diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py new file mode 100644 index 00000000..f0e03196 --- /dev/null +++ b/tests/v1/test_async_llm_dp.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +from contextlib import ExitStack +from typing import Optional + +import pytest + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs import PromptType +from vllm.platforms import current_platform +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.engine.core_client import DPAsyncMPClient + +engine_args = AsyncEngineArgs( + model="ibm-research/PowerMoE-3b", + enforce_eager=True, + disable_log_requests=True, + tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), + data_parallel_size=int(os.getenv("DP_SIZE", 2)), +) + +if not current_platform.supports_v1(engine_args.create_model_config()): + pytest.skip(reason="Requires V1-supporting platform.", + allow_module_level=True) + + +async def generate(engine: AsyncLLM, + request_id: str, + prompt: PromptType, + output_kind: RequestOutputKind, + max_tokens: int, + prompt_logprobs: Optional[int] = None) -> tuple[int, str]: + # Ensure generate doesn't complete too fast for cancellation test. 
+ await asyncio.sleep(0.2) + + count = 0 + sampling_params = SamplingParams(max_tokens=max_tokens, + ignore_eos=True, + output_kind=output_kind, + temperature=0, + prompt_logprobs=prompt_logprobs) + async for out in engine.generate(request_id=request_id, + prompt=prompt, + sampling_params=sampling_params): + + num_tokens = len(out.outputs[0].token_ids) + if output_kind == RequestOutputKind.DELTA: + count += num_tokens + else: + count = num_tokens + + await asyncio.sleep(0.) + + return count, request_id + + +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_load(output_kind: RequestOutputKind): + + with ExitStack() as after: + + prompt = "This is a test of data parallel" + + engine = AsyncLLM.from_engine_args(engine_args) + after.callback(engine.shutdown) + + NUM_REQUESTS = 100 + NUM_EXPECTED_TOKENS = 10 + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(engine, request_id, prompt, output_kind, + NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + done, pending = await asyncio.wait(tasks, + return_when=asyncio.FIRST_EXCEPTION) + for task in pending: + task.cancel() + for task in done: + num_generated_tokens, request_id = await task + assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( + f"{request_id} generated {num_generated_tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + assert not engine.output_processor.has_unfinished_requests() + + # testing internals here which may break + core_client: DPAsyncMPClient = engine.engine_core + # the engines only synchronize stopping every N steps so + # allow a small amount of time here. + for _ in range(10): + if core_client.num_engines_running == 0: + break + await asyncio.sleep(0.5) + + assert core_client.num_engines_running == 0 + assert not core_client.reqs_in_flight diff --git a/vllm/config.py b/vllm/config.py index 687c8b56..831fa2e4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -40,7 +40,8 @@ from vllm.transformers_utils.config import ( from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3, maybe_model_redirect from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, random_uuid, resolve_obj_by_qualname) + get_cpu_memory, get_open_port, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -1389,6 +1390,8 @@ class ParallelConfig: tensor_parallel_size: int = 1 # Number of tensor parallel groups. data_parallel_size: int = 1 # Number of data parallel groups. data_parallel_rank: int = 0 # Rank of the data parallel group. + # Local rank of the data parallel group, defaults to global rank. + data_parallel_rank_local: Optional[int] = None # IP of the data parallel master. data_parallel_master_ip: str = "127.0.0.1" data_parallel_master_port: int = 29500 # Port of the data parallel master. @@ -1493,10 +1496,18 @@ class ParallelConfig: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size - self.data_parallel_size = envs.VLLM_DP_SIZE - self.data_parallel_rank = envs.VLLM_DP_RANK - self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP - self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + if self.data_parallel_size > 1: + # Data parallel was specified in the engine args. 
+ self.data_parallel_master_port = get_open_port() + # TODO multi-node + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + self.world_size_across_dp = self.world_size * self.data_parallel_size if self.distributed_executor_backend == "external_launcher": diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 84899358..b8178af5 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -15,6 +15,8 @@ import torch from torch.distributed import ProcessGroup, TCPStore from torch.distributed.distributed_c10d import (Backend, PrefixStore, _get_default_timeout, + _shutdown_backend, + _unregister_process_group, is_nccl_available) from torch.distributed.rendezvous import rendezvous @@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group( pg._register_backend(device, backend_type, backend_class) return pg + + +def stateless_destroy_torch_distributed_process_group( + pg: ProcessGroup) -> None: + """ + Destroy ProcessGroup returned by + stateless_init_torch_distributed_process_group(). + """ + _shutdown_backend(pg) + _unregister_process_group(pg.group_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 53af3e57..a3b83c65 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -114,6 +114,7 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 + data_parallel_size: int = 1 enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -442,6 +443,14 @@ class EngineArgs: type=int, default=EngineArgs.tensor_parallel_size, help='Number of tensor parallel replicas.') + parser.add_argument('--data-parallel-size', + '-dp', + type=int, + default=EngineArgs.data_parallel_size, + help='Number of data parallel replicas. ' + 'MoE layers will be sharded according to the ' + 'product of the tensor-parallel-size and ' + 'data-parallel-size.') parser.add_argument( '--enable-expert-parallel', action='store_true', @@ -1359,6 +1368,7 @@ class EngineArgs: parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, + data_parallel_size=self.data_parallel_size, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/envs.py b/vllm/envs.py index e5025485..53346673 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -2,6 +2,7 @@ import hashlib import os +import sys import tempfile from typing import TYPE_CHECKING, Any, Callable, Optional @@ -95,6 +96,7 @@ if TYPE_CHECKING: VLLM_CUDART_SO_PATH: Optional[str] = None VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True VLLM_DP_RANK: int = 0 + VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 @@ -625,6 +627,12 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), + # Rank of the process in the data parallel setting. + # Defaults to VLLM_DP_RANK when not set. 
+ "VLLM_DP_RANK_LOCAL": + lambda: int( + os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)), + # World size of the data parallel setting "VLLM_DP_SIZE": lambda: int(os.getenv("VLLM_DP_SIZE", "1")), diff --git a/vllm/utils.py b/vllm/utils.py index 77f4e2dc..afe68a2b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -578,7 +578,7 @@ def get_open_port() -> int: dp_port = envs.VLLM_DP_MASTER_PORT while True: port = _get_open_port() - if port >= dp_port and port < dp_port + 10: + if dp_port <= port < dp_port + 10: continue return port return _get_open_port() @@ -2176,11 +2176,11 @@ def make_zmq_socket( if socket_type == zmq.constants.PULL: socket.setsockopt(zmq.constants.RCVHWM, 0) socket.setsockopt(zmq.constants.RCVBUF, buf_size) - socket.connect(path) + socket.bind(path) elif socket_type == zmq.constants.PUSH: socket.setsockopt(zmq.constants.SNDHWM, 0) socket.setsockopt(zmq.constants.SNDBUF, buf_size) - socket.bind(path) + socket.connect(path) else: raise ValueError(f"Unknown Socket Type: {socket_type}") @@ -2188,7 +2188,11 @@ def make_zmq_socket( @contextlib.contextmanager -def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: +def zmq_socket_ctx( + path: str, + socket_type: Any, + linger: int = 0, +) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" ctx = zmq.Context() # type: ignore[attr-defined] @@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: logger.debug("Got Keyboard Interrupt.") finally: - ctx.destroy(linger=0) + ctx.destroy(linger=linger) def is_in_ray_actor(): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 87d30c8a..44811976 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -37,9 +37,10 @@ class Scheduler(SchedulerInterface): cache_config: CacheConfig, lora_config: Optional[LoRAConfig], speculative_config: Optional[SpeculativeConfig], - log_stats: bool, structured_output_manager: StructuredOutputManager, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -48,6 +49,12 @@ class Scheduler(SchedulerInterface): self.log_stats = log_stats self.structured_output_manager = structured_output_manager + # include_finished_set controls whether a separate set of finished + # request ids should be included in the EngineCoreOutputs returned + # by update_from_outputs(). This is currently used in the multi-engine + # case to track request lifetimes efficiently. + self.include_finished_set = include_finished_set + # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_scheduled_tokens = \ @@ -663,10 +670,16 @@ class Scheduler(SchedulerInterface): new_running.append(request) self.running = new_running - return EngineCoreOutputs( + engine_core_outputs = EngineCoreOutputs( outputs=outputs, scheduler_stats=self.make_stats(), ) + if self.include_finished_set: + #TODO currently sending duplicates here, improve this + engine_core_outputs.finished_requests = ( + scheduler_output.finished_req_ids | self.finished_req_ids) + + return engine_core_outputs def add_request(self, request: Request) -> None: self.waiting.append(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 3699779b..0557d0c6 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -128,12 +128,18 @@ class EngineCoreOutputs( #NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout + engine_index: int = 0 + # [num_reqs] outputs: list[EngineCoreOutput] = [] scheduler_stats: Optional[SchedulerStats] = None timestamp: float = 0.0 utility_output: Optional[UtilityOutput] = None + finished_requests: Optional[set[str]] = None + + # In DP case, used to signal that the engine is paused. + engine_paused: bool = False def __post_init__(self): if self.timestamp == 0.0: @@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' - UTILITY = b'\x02' + START_DP = b'\x02' + UTILITY = b'\x03' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3a6811db..1fb9ae8c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -66,11 +66,17 @@ class AsyncLLM(EngineClient): self.log_requests = log_requests self.log_stats = log_stats - self.stat_loggers: list[StatLoggerBase] = [] + + # Set up stat loggers; independent set for each DP rank. + self.stat_loggers: list[list[StatLoggerBase]] = [] if self.log_stats: - if logger.isEnabledFor(logging.INFO): - self.stat_loggers.append(LoggingStatLogger()) - self.stat_loggers.append(PrometheusStatLogger(vllm_config)) + for i in range(vllm_config.parallel_config.data_parallel_size): + loggers: list[StatLoggerBase] = [] + if logger.isEnabledFor(logging.INFO): + loggers.append(LoggingStatLogger(engine_index=i)) + loggers.append( + PrometheusStatLogger(vllm_config, engine_index=i)) + self.stat_loggers.append(loggers) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -329,6 +335,7 @@ class AsyncLLM(EngineClient): # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. 
self._record_stats( + engine_index=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, ) @@ -350,12 +357,13 @@ class AsyncLLM(EngineClient): self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], + engine_index: int = 0, ): if not self.log_stats: return assert scheduler_stats is not None - for stat_logger in self.stat_loggers: + for stat_logger in self.stat_loggers[engine_index]: stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) @@ -393,8 +401,9 @@ class AsyncLLM(EngineClient): scheduler_outputs=None, model_output=None, ) -> None: - for stat_logger in self.stat_loggers: - stat_logger.log() + for loggers in self.stat_loggers: + for stat_logger in loggers: + stat_logger.log() async def check_health(self) -> None: logger.debug("Called check_health.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 42511777..20904cd4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 - +import os import queue import signal +import sys import threading import time from concurrent.futures import Future from inspect import isclass, signature -from multiprocessing.connection import Connection +from logging import DEBUG from typing import Any, Optional import msgspec @@ -14,7 +15,9 @@ import psutil import zmq import zmq.asyncio -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( @@ -91,6 +94,8 @@ class EngineCore: cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, speculative_config=vllm_config.speculative_config, + include_finished_set=vllm_config.parallel_config.data_parallel_size + > 1, log_stats=self.log_stats, structured_output_manager=self.structured_output_manager, ) @@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore): self, input_path: str, output_path: str, - ready_pipe: Connection, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + engine_index: int = 0, ): super().__init__(vllm_config, executor_class, log_stats) @@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore): args=(input_path, ), daemon=True).start() threading.Thread(target=self.process_output_socket, - args=(output_path, ), + args=(output_path, engine_index), daemon=True).start() - # Send Readiness signal to EngineClient. - ready_pipe.send({"status": "READY"}) + self.global_unfinished_reqs = False + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) @staticmethod - def run_engine_core(*args, **kwargs): + def run_engine_core(*args, + dp_rank: int = 0, + local_dp_rank: int = 0, + ready_pipe, + **kwargs): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore): signal.signal(signal.SIGINT, signal_handler) parent_process = psutil.Process().parent() - engine_core = None + engine_core: Optional[EngineCoreProc] = None try: - engine_core = EngineCoreProc(*args, **kwargs) + parallel_config: ParallelConfig = kwargs[ + "vllm_config"].parallel_config + if parallel_config.data_parallel_size > 1: + # Set data parallel rank for this engine process. 
+ parallel_config.data_parallel_rank = dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, **kwargs) + else: + engine_core = EngineCoreProc(*args, **kwargs) + + # Send Readiness signal to EngineClient. + ready_pipe.send({"status": "READY"}) + engine_core.run_busy_loop() except SystemExit: @@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore): def run_busy_loop(self): """Core busy loop of the EngineCore.""" - step_fn = (self.step - if self.batch_queue is None else self.step_with_batch_queue) - # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. - while not self.scheduler.has_requests(): - logger.debug("EngineCore busy loop waiting.") - req = self.input_queue.get() - self._handle_client_request(*req) + self._process_input_queue() + # 2) Step the engine core and return the outputs. + self._process_engine_step() - # 2) Handle any new client requests. - while not self.input_queue.empty(): - req = self.input_queue.get_nowait() - self._handle_client_request(*req) + def _process_input_queue(self): + """Exits when an engine step needs to be performed.""" - # 3) Step the engine core. - outputs = step_fn() + waited = False + while not self.global_unfinished_reqs and not ( + self.scheduler.has_requests()): + if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): + logger.debug("EngineCore waiting for work.") + waited = True + req = self.input_queue.get() + self._handle_client_request(*req) - # 4) Put EngineCoreOutputs into the output queue. - if outputs is not None: - self.output_queue.put_nowait(outputs) + if waited: + logger.debug( + "EngineCore loop active - local unfinished: %s, finished: %s.", + self.scheduler.has_unfinished_requests(), + self.scheduler.has_finished_requests()) + + # Handle any more client requests. + while not self.input_queue.empty(): + req = self.input_queue.get_nowait() + self._handle_client_request(*req) + + def _process_engine_step(self): + """Called only when there are unfinished local requests.""" + + # Step the engine core. + outputs = self.step_fn() + # Put EngineCoreOutputs into the output queue. + if outputs is not None: + self.output_queue.put_nowait(outputs) def _handle_client_request(self, request_type: EngineCoreRequestType, request: Any) -> None: @@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore): self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) + elif request_type == EngineCoreRequestType.START_DP: + if not self.global_unfinished_reqs: + logger.debug("EngineCore starting idle loop.") + self.global_unfinished_reqs = True elif request_type == EngineCoreRequestType.UTILITY: call_id, method_name, args = request output = UtilityOutput(call_id) @@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore): # Push to input queue for core busy loop. self.input_queue.put_nowait((request_type, request)) - def process_output_socket(self, output_path: str): + def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" # Msgpack serialization encoding. 
@@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore): with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket: while True: outputs = self.output_queue.get() + outputs.engine_index = engine_index encoder.encode_into(outputs, buffer) - socket.send_multipart((buffer, ), copy=False) + socket.send(buffer, copy=False) + + +ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) + + +class DPEngineCoreProc(EngineCoreProc): + """ZMQ-wrapper for running EngineCore in background process + in a data parallel context.""" + + def __init__( + self, + input_path: str, + output_path: str, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ): + # Add process-specific prefix to stdout and stderr before + # we initialize the engine. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + dp_size = vllm_config.parallel_config.data_parallel_size + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + + assert dp_size > 1 + assert 0 <= local_dp_rank <= dp_rank < dp_size + + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + from vllm.platforms.cuda import device_id_to_physical_device_id + tp_size = vllm_config.parallel_config.tensor_parallel_size + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * + tp_size)) + + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + + # Initialize the engine after setting up environment. + super().__init__(input_path, output_path, vllm_config, executor_class, + log_stats, dp_rank) + + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + def shutdown(self): + super().shutdown() + if dp_group := getattr(self, "dp_group", None): + stateless_destroy_torch_distributed_process_group(dp_group) + + def run_busy_loop(self): + """Core busy loop of the EngineCore for data parallel case.""" + + # Loop until process is sent a SIGINT or SIGTERM + while True: + # 1) Poll the input queue until there is work to do. + self._process_input_queue() + + local_unfinished_reqs = self.scheduler.has_unfinished_requests() + + if local_unfinished_reqs: + # 2) Step the engine core. + self._process_engine_step() + + # Check if we have now finished all requests. + local_unfinished_reqs = ( + self.scheduler.has_unfinished_requests()) + else: + if self.scheduler.has_finished_requests(): + # There are no unfinished requests, but there are some + # finished requests remaining to be removed from the + # batch state. This engine step won't perform a forward + # pass but will flush the finished requests to ensure + # up-to-date state is returned in the engine outputs. + self._process_engine_step() + + if not self.global_unfinished_reqs: + # All engines are idle. + continue + + # There must be unfinished requests in DP peers, run a + # dummy forward pass. + self.execute_dummy_batch() + + # 3) All-reduce operation to determine global unfinished reqs. + self.global_unfinished_reqs = self._has_global_unfinished_reqs( + local_unfinished_reqs) + + if not self.global_unfinished_reqs: + # Notify client that we are pausing the loop. 
+ self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) + + def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: + + # Optimization - only perform finish-sync all-reduce every 16 steps. + self.counter += 1 + if self.counter != 16: + return True + self.counter = 0 + + return ParallelConfig.has_unfinished_dp(self.dp_group, + local_unfinished) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 13b72c80..c41ee670 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -8,10 +8,11 @@ import threading import uuid import weakref from abc import ABC, abstractmethod +from collections.abc import Awaitable, Sequence from concurrent.futures import Future -from dataclasses import dataclass +from dataclasses import dataclass, field from threading import Thread -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import zmq import zmq.asyncio @@ -60,6 +61,9 @@ class EngineCoreClient(ABC): "is not currently supported.") if multiprocess_mode and asyncio_mode: + if vllm_config.parallel_config.data_parallel_size > 1: + return DPAsyncMPClient(vllm_config, executor_class, log_stats) + return AsyncMPClient(vllm_config, executor_class, log_stats) if multiprocess_mode and not asyncio_mode: @@ -207,28 +211,74 @@ class InprocClient(EngineCoreClient): return self.engine_core.pin_lora(lora_id) +class CoreEngine: + """One per data parallel rank.""" + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + ctx: Union[zmq.Context, zmq.asyncio.Context], + output_path: str, + index: int = 0, + local_dp_rank: int = 0, + ): + # Paths and sockets for IPC. + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(ctx, input_path, + zmq.constants.PUSH) + try: + # Start EngineCore in background process. + self.proc_handle = BackgroundProcHandle( + input_path=input_path, + output_path=output_path, + process_name=f"EngineCore_{index}", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "dp_rank": index, + "local_dp_rank": local_dp_rank, + "executor_class": executor_class, + "log_stats": log_stats, + }) + + self.num_reqs_in_flight = 0 + finally: + if not hasattr(self, "num_reqs_in_flight"): + # Ensure socket is closed if process fails to start. + self.close() + + def send_multipart(self, msg_parts: Sequence): + return self.input_socket.send_multipart(msg_parts, copy=False) + + def close(self): + if proc_handle := getattr(self, "proc_handle", None): + proc_handle.shutdown() + if socket := getattr(self, "input_socket", None): + socket.close(linger=0) + + @dataclass class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding circular reference back to the client object.""" - ctx: zmq.Context + ctx: Union[zmq.Context] + core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None - input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None - proc_handle: Optional[BackgroundProcHandle] = None shutdown_path: Optional[str] = None def __call__(self): """Clean up background resources.""" - if self.proc_handle is not None: - self.proc_handle.shutdown() + for core_engine in self.core_engines: + core_engine.close() + # ZMQ context termination can hang if the sockets # aren't explicitly closed first. 
if self.output_socket is not None: self.output_socket.close(linger=0) - if self.input_socket is not None: - self.input_socket.close(linger=0) if self.shutdown_path is not None: # We must ensure that the sync output socket is # closed cleanly in its own thread. @@ -284,7 +334,7 @@ class MPClient(EngineCoreClient): self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. - sync_ctx = zmq.Context() + sync_ctx = zmq.Context(io_threads=2) self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx # This will ensure resources created so far are closed @@ -293,28 +343,38 @@ class MPClient(EngineCoreClient): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # Paths for IPC. + # Paths and sockets for IPC. self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - # Start EngineCore in background process. - self.resources.proc_handle = BackgroundProcHandle( - input_path=input_path, - output_path=self.output_path, - process_name="EngineCore", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "executor_class": executor_class, - "log_stats": log_stats, - }) + new_core_engine = lambda index, local_dp_rank=None: CoreEngine( + vllm_config, executor_class, log_stats, self.ctx, self.output_path, + index, local_dp_rank) + + # Start engine core process(es). + self._init_core_engines(vllm_config, new_core_engine, + self.resources.core_engines) + + # Wait for engine core process(es) to start. + for engine in self.resources.core_engines: + engine.proc_handle.wait_for_startup() - # Create input socket. - self.resources.input_socket = make_zmq_socket(self.ctx, input_path, - zmq.constants.PUSH) - self.input_socket = self.resources.input_socket self.utility_results: dict[int, AnyFuture] = {} + def _init_core_engines( + self, + vllm_config: VllmConfig, + new_core_engine: Callable[[int, Optional[int]], CoreEngine], + core_engines: list[CoreEngine], + ) -> None: + + # Default case - single core engine. + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + core_engine = new_core_engine( + dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) + core_engines.append(core_engine) + self.core_engine = core_engine + def shutdown(self): self._finalizer() @@ -370,7 +430,7 @@ class SyncMPClient(MPClient): # shutdown signal, exit thread. 
break - (frame, ) = out_socket.recv_multipart(copy=False) + frame = out_socket.recv(copy=False) outputs = decoder.decode(frame.buffer) if outputs.utility_output: _process_utility_output(outputs.utility_output, @@ -391,18 +451,15 @@ class SyncMPClient(MPClient): def get_output(self) -> EngineCoreOutputs: return self.outputs_queue.get() - def _send_input(self, request_type: EngineCoreRequestType, - request: Any) -> None: - + def _send_input(self, request_type: EngineCoreRequestType, request: Any): # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) + self.core_engine.send_multipart(msg) - def _call_utility(self, method: str, *args) -> Any: + def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 future: Future[Any] = Future() self.utility_results[call_id] = future - self._send_input(EngineCoreRequestType.UTILITY, (call_id, method, args)) @@ -419,34 +476,34 @@ class SyncMPClient(MPClient): self._send_input(EngineCoreRequestType.ABORT, request_ids) def profile(self, is_start: bool = True) -> None: - self._call_utility("profile", is_start) + self.call_utility("profile", is_start) def reset_prefix_cache(self) -> None: - self._call_utility("reset_prefix_cache") + self.call_utility("reset_prefix_cache") def add_lora(self, lora_request: LoRARequest) -> bool: - return self._call_utility("add_lora", lora_request) + return self.call_utility("add_lora", lora_request) def remove_lora(self, lora_id: int) -> bool: - return self._call_utility("remove_lora", lora_id) + return self.call_utility("remove_lora", lora_id) def list_loras(self) -> set[int]: - return self._call_utility("list_loras") + return self.call_utility("list_loras") def pin_lora(self, lora_id: int) -> bool: - return self._call_utility("pin_lora", lora_id) + return self.call_utility("pin_lora", lora_id) def sleep(self, level: int = 1) -> None: - self._call_utility("sleep", level) + self.call_utility("sleep", level) def wake_up(self) -> None: - self._call_utility("wake_up") + self.call_utility("wake_up") def is_sleeping(self) -> bool: - return self._call_utility("is_sleeping") + return self.call_utility("is_sleeping") def execute_dummy_batch(self) -> None: - self._call_utility("execute_dummy_batch") + self.call_utility("execute_dummy_batch") class AsyncMPClient(MPClient): @@ -464,13 +521,21 @@ class AsyncMPClient(MPClient): self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None self.queue_task: Optional[asyncio.Task] = None - async def _start_output_queue_task(self): + self.outputs_handler: Optional[Callable[ + [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None + + def _ensure_output_queue_task(self): + if self.outputs_queue is not None: + return + # Perform IO in separate task to parallelize as much as possible. # Avoid task having direct reference back to the client. 
self.outputs_queue = asyncio.Queue() decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue + output_handler = self.outputs_handler + _self_ref = weakref.ref(self) if output_handler else None output_path = self.output_path output_socket = make_zmq_socket(self.ctx, output_path, zmq.constants.PULL) @@ -483,34 +548,52 @@ class AsyncMPClient(MPClient): if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) - else: + continue + + if output_handler is not None: + assert _self_ref is not None + _self = _self_ref() + if not _self: + # Client has been garbage collected, abort. + return + await output_handler(_self, outputs) + + if outputs.outputs or outputs.scheduler_stats: outputs_queue.put_nowait(outputs) self.queue_task = asyncio.create_task(process_outputs_socket(), name="EngineCoreOutputQueueTask") async def get_output_async(self) -> EngineCoreOutputs: - if self.outputs_queue is None: - await self._start_output_queue_task() - assert self.outputs_queue is not None + self._ensure_output_queue_task() + assert self.outputs_queue is not None return await self.outputs_queue.get() async def _send_input(self, request_type: EngineCoreRequestType, request: Any) -> None: + await self.core_engine.send_multipart( + (request_type.value, self.encoder.encode(request))) - msg = (request_type.value, self.encoder.encode(request)) - await self.input_socket.send_multipart(msg, copy=False) + self._ensure_output_queue_task() - if self.outputs_queue is None: - await self._start_output_queue_task() + async def call_utility_async(self, method: str, *args) -> Any: + return await self._call_utility_async(method, + *args, + engine=self.core_engine) - async def _call_utility_async(self, method: str, *args) -> Any: + async def _call_utility_async( + self, + method: str, + *args, + engine: CoreEngine, + ) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future - await self._send_input(EngineCoreRequestType.UTILITY, - (call_id, method, args)) - + message = (EngineCoreRequestType.UTILITY.value, + self.encoder.encode((call_id, method, args))) + await engine.send_multipart(message) + self._ensure_output_queue_task() return await future async def add_request_async(self, request: EngineCoreRequest) -> None: @@ -524,31 +607,146 @@ class AsyncMPClient(MPClient): await self._send_input(EngineCoreRequestType.ABORT, request_ids) async def profile_async(self, is_start: bool = True) -> None: - await self._call_utility_async("profile", is_start) + await self.call_utility_async("profile", is_start) async def reset_prefix_cache_async(self) -> None: - await self._call_utility_async("reset_prefix_cache") + await self.call_utility_async("reset_prefix_cache") async def sleep_async(self, level: int = 1) -> None: - await self._call_utility_async("sleep", level) + await self.call_utility_async("sleep", level) async def wake_up_async(self) -> None: - await self._call_utility_async("wake_up") + await self.call_utility_async("wake_up") async def is_sleeping_async(self) -> bool: - return await self._call_utility_async("is_sleeping") + return await self.call_utility_async("is_sleeping") async def execute_dummy_batch_async(self) -> None: - await self._call_utility_async("execute_dummy_batch") + await self.call_utility_async("execute_dummy_batch") async def add_lora_async(self, lora_request: LoRARequest) -> bool: - return await self._call_utility_async("add_lora", lora_request) + return await 
self.call_utility_async("add_lora", lora_request) async def remove_lora_async(self, lora_id: int) -> bool: - return await self._call_utility_async("remove_lora", lora_id) + return await self.call_utility_async("remove_lora", lora_id) async def list_loras_async(self) -> set[int]: - return await self._call_utility_async("list_loras") + return await self.call_utility_async("list_loras") async def pin_lora_async(self, lora_id: int) -> bool: - return await self._call_utility_async("pin_lora", lora_id) + return await self.call_utility_async("pin_lora", lora_id) + + +class DPAsyncMPClient(AsyncMPClient): + """Asyncio-compatible client for multi-proc, multi-engine (data parallel) + EngineCore.""" + + def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], + log_stats: bool): + super().__init__(vllm_config, executor_class, log_stats) + + assert len(self.core_engines) > 1 + + # Control message used for triggering dp idle mode loop. + self.start_dp_msg = (EngineCoreRequestType.START_DP.value, + self.encoder.encode(None)) + + self.num_engines_running = 0 + self.reqs_in_flight: dict[str, CoreEngine] = {} + + self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] + + def _init_core_engines( + self, + vllm_config: VllmConfig, + new_core_engine: Callable[[int, Optional[int]], CoreEngine], + core_engines: list[CoreEngine], + ) -> None: + + # Launch a core engine for each data parallel rank. + dp_size = vllm_config.parallel_config.data_parallel_size + for i in range(dp_size): + # Multi-node not yet supported so local_dp_rank == dp_rank. + core_engines.append(new_core_engine(i, i)) + + self.core_engines = core_engines + + async def call_utility_async(self, method: str, *args) -> Any: + # Only the result from the first engine is returned. + return (await asyncio.gather(*[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ]))[0] + + async def add_request_async(self, request: EngineCoreRequest) -> None: + # NOTE: text prompt is not needed in the core engine as it has been + # tokenized. + request.prompt = None + + msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request)) + + chosen_engine = self.get_core_engine_for_request() + self.reqs_in_flight[request.request_id] = chosen_engine + chosen_engine.num_reqs_in_flight += 1 + if self.num_engines_running >= len(self.core_engines): + await chosen_engine.send_multipart(msg) + else: + # Send request to chosen engine and dp start loop + # control message to all other engines. + self.num_engines_running += len(self.core_engines) + await asyncio.gather(*[ + engine.send_multipart(msg if engine is + chosen_engine else self.start_dp_msg) + for engine in self.core_engines + ]) + + self._ensure_output_queue_task() + + def get_core_engine_for_request(self) -> CoreEngine: + return min(self.core_engines, key=lambda e: e.num_reqs_in_flight) + + @staticmethod + async def process_engine_outputs(self: "DPAsyncMPClient", + outputs: EngineCoreOutputs): + if self.reqs_in_flight: + for req_id in outputs.finished_requests or (): + if engine := self.reqs_in_flight.pop(req_id, None): + engine.num_reqs_in_flight -= 1 + + if outputs.engine_paused: + assert self.num_engines_running >= 1 + self.num_engines_running -= 1 + if not self.num_engines_running and self.reqs_in_flight: + # If there are requests in flight here, they must have + # been sent after the engines paused. 
We must make + # sure to start the other engines: + self.num_engines_running = len(self.core_engines) + coros = [ + engine.send_multipart(self.start_dp_msg) + for engine in self.core_engines + if not engine.num_reqs_in_flight + ] + if coros: + await asyncio.gather(*coros) + + async def abort_requests_async(self, request_ids: list[str]) -> None: + if not request_ids: + return + + if len(request_ids) == 1: + # Fast-path common case. + if engine := self.reqs_in_flight.get(request_ids[0]): + await self._abort_requests(request_ids, engine) + return + + by_engine: dict[CoreEngine, list[str]] = {} + for req_id in request_ids: + if engine := self.reqs_in_flight.get(req_id): + by_engine.setdefault(engine, []).append(req_id) + for engine, req_ids in by_engine.items(): + await self._abort_requests(req_ids, engine) + + async def _abort_requests(self, request_ids: list[str], + engine: CoreEngine) -> None: + await engine.send_multipart((EngineCoreRequestType.ABORT.value, + self.encoder.encode(request_ids))) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7bda3a30..8cc73f9f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -8,6 +8,7 @@ from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType @@ -60,11 +61,13 @@ class LLMEngine: self.cache_config = vllm_config.cache_config # important: init dp group before init the engine_core - self.parallel_config = vllm_config.parallel_config - self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa + # In the decoupled engine case this is handled in EngineCoreProc. + parallel_config = vllm_config.parallel_config + if not multiprocess_mode and parallel_config.data_parallel_size > 1: + self.dp_group = parallel_config.stateless_init_dp_group() + else: + self.dp_group = None self.should_execute_dummy_batch = False - if self.dp_enabled: - self.dp_group = self.parallel_config.stateless_init_dp_group() # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -148,7 +151,7 @@ class LLMEngine: def has_unfinished_requests(self) -> bool: has_unfinished = self.output_processor.has_unfinished_requests() - if not self.dp_enabled: + if self.dp_group is None: return has_unfinished return self.has_unfinished_requests_dp(has_unfinished) @@ -280,3 +283,7 @@ class LLMEngine: def pin_lora(self, lora_id: int) -> bool: """Prevent an adapter from being evicted.""" return self.engine_core.pin_lora(lora_id) + + def __del__(self): + if dp_group := getattr(self, "dp_group", None): + stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 21e7d265..1d5175eb 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -235,7 +235,10 @@ class WorkerProc: worker_response_mq_handle = self.worker_response_mq.export_handle() # Send Readiness signal to EngineCore process. - with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket: + # Set linger here because we want to ensure the message has + # been sent before the context is closed. 
+ with zmq_socket_ctx(ready_path, zmq.constants.PUSH, + linger=10000) as ready_socket: payload = pickle.dumps(worker_response_mq_handle, protocol=pickle.HIGHEST_PROTOCOL) ready_socket.send_string(WorkerProc.READY_STR) @@ -270,11 +273,13 @@ class WorkerProc: proc = context.Process(target=WorkerProc.worker_main, kwargs=process_kwargs, daemon=True) - proc.start() - # Wait for startup - worker_response_mq_handle = WorkerProc.wait_for_startup( - proc, ready_path) + with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket: + proc.start() + + # Wait for startup + worker_response_mq_handle = WorkerProc.wait_for_startup( + proc, ready_socket) worker_response_mq = MessageQueue.create_from_handle( worker_response_mq_handle, 0) @@ -337,23 +342,22 @@ class WorkerProc: @staticmethod def wait_for_startup( proc: BaseProcess, - ready_path: str, + ready_socket: zmq.Socket, ) -> Optional[Handle]: """Wait until the Worker is ready.""" - with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket: - # Wait for Worker to send READY. - while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - logger.debug("Waiting for WorkerProc to startup.") + # Wait for Worker to send READY. + while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for WorkerProc to startup.") - if not proc.is_alive(): - raise RuntimeError("WorkerProc failed to start.") + if not proc.is_alive(): + raise RuntimeError("WorkerProc failed to start.") - message = socket.recv_string() - assert message == WorkerProc.READY_STR - handle_frame = socket.recv(copy=False) - handle = pickle.loads(handle_frame.buffer) - return handle + message = ready_socket.recv_string() + assert message == WorkerProc.READY_STR + handle_frame = ready_socket.recv(copy=False) + handle = pickle.loads(handle_frame.buffer) + return handle class ResponseStatus(Enum): SUCCESS = auto() diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index fcb4d4f5..6ffd00eb 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -31,7 +31,8 @@ class StatLoggerBase(ABC): class LoggingStatLogger(StatLoggerBase): - def __init__(self): + def __init__(self, engine_index: int = 0): + self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() # Prefix cache metrics. This cannot be reset. @@ -78,11 +79,13 @@ class LoggingStatLogger(StatLoggerBase): # Format and print output. 
logger.info( + "Engine %03d: " "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Waiting: %d reqs, " "GPU KV cache usage: %.1f%%, " "Prefix cache hit rate: %.1f%%", + self.engine_index, prompt_throughput, generation_throughput, scheduler_stats.num_running_reqs, @@ -94,7 +97,7 @@ class LoggingStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase): - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self._unregister_vllm_metrics() # Use this flag to hide metrics that were deprecated in @@ -102,8 +105,11 @@ class PrometheusStatLogger(StatLoggerBase): self.show_hidden_metrics = \ vllm_config.observability_config.show_hidden_metrics - labelnames = ["model_name"] - labelvalues = [vllm_config.model_config.served_model_name] + labelnames = ["model_name", "engine"] + labelvalues = [ + vllm_config.model_config.served_model_name, + str(engine_index) + ] max_model_len = vllm_config.model_config.max_model_len diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 6c01ed3d..f42b3501 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -105,7 +105,7 @@ class BackgroundProcHandle: process_kwargs: dict[Any, Any], ): context = get_mp_context() - reader, writer = context.Pipe(duplex=False) + self.reader, writer = context.Pipe(duplex=False) assert ("ready_pipe" not in process_kwargs and "input_path" not in process_kwargs @@ -115,14 +115,17 @@ class BackgroundProcHandle: process_kwargs["output_path"] = output_path # Run busy loop in background process. - self.proc = context.Process(target=target_fn, kwargs=process_kwargs) + self.proc = context.Process(target=target_fn, + kwargs=process_kwargs, + name=process_name) self._finalizer = weakref.finalize(self, shutdown, self.proc, input_path, output_path) self.proc.start() + def wait_for_startup(self): # Wait for startup. - if reader.recv()["status"] != "READY": - raise RuntimeError(f"{process_name} initialization failed. " + if self.reader.recv()["status"] != "READY": + raise RuntimeError(f"{self.proc.name} initialization failed. " "See root cause above.") def shutdown(self):