[V1] AsyncLLM data parallel (#13923)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
112b3e5b3b
commit
15dac210f0
@ -135,12 +135,14 @@ steps:
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
@ -514,7 +516,10 @@ steps:
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
|
@ -28,6 +28,7 @@ Multi-node:
|
||||
--master-port=13345
|
||||
"""
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port
|
||||
@ -36,14 +37,13 @@ from vllm.utils import get_open_port
|
||||
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
dp_master_port, GPUs_per_dp_rank):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
|
||||
# set devices for each dp_rank
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
str(i)
|
||||
for i in range(local_dp_rank * GPUs_per_dp_rank, (local_dp_rank + 1) *
|
||||
GPUs_per_dp_rank))
|
||||
|
||||
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
|
||||
# engine processes.
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
@ -90,6 +90,9 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
|
||||
# Give engines time to pause their processing loops before exiting.
|
||||
sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
@ -152,8 +155,13 @@ if __name__ == "__main__":
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join()
|
||||
if proc.exitcode:
|
||||
proc.join(timeout=300)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that "
|
||||
f"didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
exit_code = 1
|
||||
elif proc.exitcode:
|
||||
exit_code = proc.exitcode
|
||||
|
||||
exit(exit_code)
|
||||
|
@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
core_client: SyncMPClient = client
|
||||
|
||||
result = core_client._call_utility("echo", "testarg")
|
||||
result = core_client.call_utility("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
core_client._call_utility("echo", None, "help!")
|
||||
core_client.call_utility("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
||||
@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client._call_utility_async("echo", "testarg")
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client._call_utility_async("echo", None, "help!")
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
109
tests/v1/test_async_llm_dp.py
Normal file
109
tests/v1/test_async_llm_dp.py
Normal file
@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import DPAsyncMPClient
|
||||
|
||||
# Engine configuration shared by every test in this module. The tensor- and
# data-parallel sizes are read from the TP_SIZE / DP_SIZE environment
# variables (defaulting to 1 and 2 respectively) so that the same module can
# be run with different parallel layouts.
engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",
    tensor_parallel_size=int(os.environ.get("TP_SIZE", 1)),
    data_parallel_size=int(os.environ.get("DP_SIZE", 2)),
    enforce_eager=True,
    disable_log_requests=True,
)

# Skip the whole module when the current platform reports that it cannot run
# the V1 engine for this model configuration.
if not current_platform.supports_v1(engine_args.create_model_config()):
    pytest.skip(
        reason="Requires V1-supporting platform.",
        allow_module_level=True,
    )
|
||||
|
||||
|
||||
async def generate(engine: AsyncLLM,
                   request_id: str,
                   prompt: PromptType,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
    """Run one request through ``engine`` and count the tokens it produced.

    Args:
        engine: Engine whose ``generate()`` async iterator is consumed.
        request_id: Identifier passed through to the engine and returned.
        prompt: Prompt to generate from.
        output_kind: Output mode. With ``RequestOutputKind.DELTA`` the
            per-output token counts are summed; otherwise the count of the
            most recent output is used.
        max_tokens: Passed to ``SamplingParams`` (``ignore_eos`` is set).
        prompt_logprobs: Passed to ``SamplingParams``.

    Returns:
        ``(token_count, request_id)``.
    """
    params = SamplingParams(max_tokens=max_tokens,
                            ignore_eos=True,
                            output_kind=output_kind,
                            temperature=0,
                            prompt_logprobs=prompt_logprobs)
    accumulate = output_kind == RequestOutputKind.DELTA

    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    stream = engine.generate(request_id=request_id,
                             prompt=prompt,
                             sampling_params=params)
    async for out in stream:
        produced = len(out.outputs[0].token_ids)
        count = count + produced if accumulate else produced

        # Hand control back to the event loop after every output.
        await asyncio.sleep(0.)

    return count, request_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):
    """Fire 100 concurrent requests at a data-parallel AsyncLLM.

    Checks that every request yields the expected number of tokens, that the
    output processor has nothing left unfinished, and that the DP client
    eventually reports no running engines and no requests in flight.
    """
    NUM_REQUESTS = 100
    NUM_EXPECTED_TOKENS = 10
    prompt = "This is a test of data parallel"

    with ExitStack() as after:
        engine = AsyncLLM.from_engine_args(engine_args)
        # Registered immediately so the engine is shut down however we exit.
        after.callback(engine.shutdown)

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, f"request-{i}", prompt, output_kind,
                         NUM_EXPECTED_TOKENS)) for i in range(NUM_REQUESTS)
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        finished, unfinished = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_EXCEPTION)
        for leftover in unfinished:
            leftover.cancel()
        for completed in finished:
            num_generated_tokens, request_id = await completed
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()

        # testing internals here which may break
        core_client: DPAsyncMPClient = engine.engine_core

        # the engines only synchronize stopping every N steps so
        # allow a small amount of time here (at most 10 polls, 0.5s apart).
        attempts_left = 10
        while attempts_left > 0 and core_client.num_engines_running != 0:
            await asyncio.sleep(0.5)
            attempts_left -= 1

        assert core_client.num_engines_running == 0
        assert not core_client.reqs_in_flight
|
@ -40,7 +40,8 @@ from vllm.transformers_utils.config import (
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||
get_cpu_memory, random_uuid, resolve_obj_by_qualname)
|
||||
get_cpu_memory, get_open_port, random_uuid,
|
||||
resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@ -1389,6 +1390,8 @@ class ParallelConfig:
|
||||
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
|
||||
data_parallel_size: int = 1 # Number of data parallel groups.
|
||||
data_parallel_rank: int = 0 # Rank of the data parallel group.
|
||||
# Local rank of the data parallel group, defaults to global rank.
|
||||
data_parallel_rank_local: Optional[int] = None
|
||||
# IP of the data parallel master.
|
||||
data_parallel_master_ip: str = "127.0.0.1"
|
||||
data_parallel_master_port: int = 29500 # Port of the data parallel master.
|
||||
@ -1493,10 +1496,18 @@ class ParallelConfig:
|
||||
self.world_size = self.pipeline_parallel_size * \
|
||||
self.tensor_parallel_size
|
||||
|
||||
self.data_parallel_size = envs.VLLM_DP_SIZE
|
||||
self.data_parallel_rank = envs.VLLM_DP_RANK
|
||||
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
|
||||
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
if self.data_parallel_size > 1:
|
||||
# Data parallel was specified in the engine args.
|
||||
self.data_parallel_master_port = get_open_port()
|
||||
# TODO multi-node
|
||||
else:
|
||||
# Otherwise fall back to env vars (e.g. for offline SPMD case).
|
||||
self.data_parallel_size = envs.VLLM_DP_SIZE
|
||||
self.data_parallel_rank = envs.VLLM_DP_RANK
|
||||
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
|
||||
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
|
||||
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
|
||||
self.world_size_across_dp = self.world_size * self.data_parallel_size
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
|
@ -15,6 +15,8 @@ import torch
|
||||
from torch.distributed import ProcessGroup, TCPStore
|
||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
||||
_get_default_timeout,
|
||||
_shutdown_backend,
|
||||
_unregister_process_group,
|
||||
is_nccl_available)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
@ -333,3 +335,13 @@ def stateless_init_torch_distributed_process_group(
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
|
||||
return pg
|
||||
|
||||
|
||||
def stateless_destroy_torch_distributed_process_group(
        pg: ProcessGroup) -> None:
    """Tear down a process group created statelessly.

    Counterpart of stateless_init_torch_distributed_process_group(): the
    group's backend is shut down first, then the group is unregistered
    under its ``group_name``.

    Args:
        pg: The ProcessGroup returned by
            stateless_init_torch_distributed_process_group().
    """
    _shutdown_backend(pg)
    # Read the name only after the backend shutdown, as the original did.
    group_name = pg.group_name
    _unregister_process_group(group_name)
|
||||
|
@ -114,6 +114,7 @@ class EngineArgs:
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = 1
|
||||
tensor_parallel_size: int = 1
|
||||
data_parallel_size: int = 1
|
||||
enable_expert_parallel: bool = False
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
block_size: Optional[int] = None
|
||||
@ -442,6 +443,14 @@ class EngineArgs:
|
||||
type=int,
|
||||
default=EngineArgs.tensor_parallel_size,
|
||||
help='Number of tensor parallel replicas.')
|
||||
parser.add_argument('--data-parallel-size',
|
||||
'-dp',
|
||||
type=int,
|
||||
default=EngineArgs.data_parallel_size,
|
||||
help='Number of data parallel replicas. '
|
||||
'MoE layers will be sharded according to the '
|
||||
'product of the tensor-parallel-size and '
|
||||
'data-parallel-size.')
|
||||
parser.add_argument(
|
||||
'--enable-expert-parallel',
|
||||
action='store_true',
|
||||
@ -1359,6 +1368,7 @@ class EngineArgs:
|
||||
parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
data_parallel_size=self.data_parallel_size,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
@ -95,6 +96,7 @@ if TYPE_CHECKING:
|
||||
VLLM_CUDART_SO_PATH: Optional[str] = None
|
||||
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
|
||||
VLLM_DP_RANK: int = 0
|
||||
VLLM_DP_RANK_LOCAL: int = -1
|
||||
VLLM_DP_SIZE: int = 1
|
||||
VLLM_DP_MASTER_IP: str = ""
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
@ -625,6 +627,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_DP_RANK":
|
||||
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
|
||||
|
||||
# Local rank of the process in the data parallel setting.
|
||||
# Defaults to VLLM_DP_RANK when not set.
|
||||
"VLLM_DP_RANK_LOCAL":
|
||||
lambda: int(
|
||||
os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)),
|
||||
|
||||
# World size of the data parallel setting
|
||||
"VLLM_DP_SIZE":
|
||||
lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
|
||||
|
@ -578,7 +578,7 @@ def get_open_port() -> int:
|
||||
dp_port = envs.VLLM_DP_MASTER_PORT
|
||||
while True:
|
||||
port = _get_open_port()
|
||||
if port >= dp_port and port < dp_port + 10:
|
||||
if dp_port <= port < dp_port + 10:
|
||||
continue
|
||||
return port
|
||||
return _get_open_port()
|
||||
@ -2176,11 +2176,11 @@ def make_zmq_socket(
|
||||
if socket_type == zmq.constants.PULL:
|
||||
socket.setsockopt(zmq.constants.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
|
||||
socket.connect(path)
|
||||
socket.bind(path)
|
||||
elif socket_type == zmq.constants.PUSH:
|
||||
socket.setsockopt(zmq.constants.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
|
||||
socket.bind(path)
|
||||
socket.connect(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {socket_type}")
|
||||
|
||||
@ -2188,7 +2188,11 @@ def make_zmq_socket(
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
linger: int = 0,
|
||||
) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
@ -2199,7 +2203,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=0)
|
||||
ctx.destroy(linger=linger)
|
||||
|
||||
|
||||
def is_in_ray_actor():
|
||||
|
@ -37,9 +37,10 @@ class Scheduler(SchedulerInterface):
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
log_stats: bool,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
@ -48,6 +49,12 @@ class Scheduler(SchedulerInterface):
|
||||
self.log_stats = log_stats
|
||||
self.structured_output_manager = structured_output_manager
|
||||
|
||||
# include_finished_set controls whether a separate set of finished
|
||||
# request ids should be included in the EngineCoreOutputs returned
|
||||
# by update_from_outputs(). This is currently used in the multi-engine
|
||||
# case to track request lifetimes efficiently.
|
||||
self.include_finished_set = include_finished_set
|
||||
|
||||
# Scheduling constraints.
|
||||
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_scheduled_tokens = \
|
||||
@ -663,10 +670,16 @@ class Scheduler(SchedulerInterface):
|
||||
new_running.append(request)
|
||||
|
||||
self.running = new_running
|
||||
return EngineCoreOutputs(
|
||||
engine_core_outputs = EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
scheduler_stats=self.make_stats(),
|
||||
)
|
||||
if self.include_finished_set:
|
||||
#TODO currently sending duplicates here, improve this
|
||||
engine_core_outputs.finished_requests = (
|
||||
scheduler_output.finished_req_ids | self.finished_req_ids)
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
self.waiting.append(request)
|
||||
|
@ -128,12 +128,18 @@ class EngineCoreOutputs(
|
||||
#NOTE(Nick): We could consider ways to make this more compact,
|
||||
# e.g. columnwise layout
|
||||
|
||||
engine_index: int = 0
|
||||
|
||||
# [num_reqs]
|
||||
outputs: list[EngineCoreOutput] = []
|
||||
scheduler_stats: Optional[SchedulerStats] = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
utility_output: Optional[UtilityOutput] = None
|
||||
finished_requests: Optional[set[str]] = None
|
||||
|
||||
# In DP case, used to signal that the engine is paused.
|
||||
engine_paused: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp == 0.0:
|
||||
@ -147,4 +153,5 @@ class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
UTILITY = b'\x02'
|
||||
START_DP = b'\x02'
|
||||
UTILITY = b'\x03'
|
||||
|
@ -66,11 +66,17 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
self.stat_loggers: list[StatLoggerBase] = []
|
||||
|
||||
# Set up stat loggers; independent set for each DP rank.
|
||||
self.stat_loggers: list[list[StatLoggerBase]] = []
|
||||
if self.log_stats:
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
self.stat_loggers.append(LoggingStatLogger())
|
||||
self.stat_loggers.append(PrometheusStatLogger(vllm_config))
|
||||
for i in range(vllm_config.parallel_config.data_parallel_size):
|
||||
loggers: list[StatLoggerBase] = []
|
||||
if logger.isEnabledFor(logging.INFO):
|
||||
loggers.append(LoggingStatLogger(engine_index=i))
|
||||
loggers.append(
|
||||
PrometheusStatLogger(vllm_config, engine_index=i))
|
||||
self.stat_loggers.append(loggers)
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
@ -329,6 +335,7 @@ class AsyncLLM(EngineClient):
|
||||
# TODO(rob): make into a coroutine and launch it in
|
||||
# background thread once Prometheus overhead is non-trivial.
|
||||
self._record_stats(
|
||||
engine_index=outputs.engine_index,
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
@ -350,12 +357,13 @@ class AsyncLLM(EngineClient):
|
||||
self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_index: int = 0,
|
||||
):
|
||||
if not self.log_stats:
|
||||
return
|
||||
|
||||
assert scheduler_stats is not None
|
||||
for stat_logger in self.stat_loggers:
|
||||
for stat_logger in self.stat_loggers[engine_index]:
|
||||
stat_logger.record(scheduler_stats=scheduler_stats,
|
||||
iteration_stats=iteration_stats)
|
||||
|
||||
@ -393,8 +401,9 @@ class AsyncLLM(EngineClient):
|
||||
scheduler_outputs=None,
|
||||
model_output=None,
|
||||
) -> None:
|
||||
for stat_logger in self.stat_loggers:
|
||||
stat_logger.log()
|
||||
for loggers in self.stat_loggers:
|
||||
for stat_logger in loggers:
|
||||
stat_logger.log()
|
||||
|
||||
async def check_health(self) -> None:
|
||||
logger.debug("Called check_health.")
|
||||
|
@ -1,12 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from multiprocessing.connection import Connection
|
||||
from logging import DEBUG
|
||||
from typing import Any, Optional
|
||||
|
||||
import msgspec
|
||||
@ -14,7 +15,9 @@ import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.config import (
|
||||
@ -91,6 +94,8 @@ class EngineCore:
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
speculative_config=vllm_config.speculative_config,
|
||||
include_finished_set=vllm_config.parallel_config.data_parallel_size
|
||||
> 1,
|
||||
log_stats=self.log_stats,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
)
|
||||
@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore):
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
ready_pipe: Connection,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
engine_index: int = 0,
|
||||
):
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
|
||||
@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore):
|
||||
args=(input_path, ),
|
||||
daemon=True).start()
|
||||
threading.Thread(target=self.process_output_socket,
|
||||
args=(output_path, ),
|
||||
args=(output_path, engine_index),
|
||||
daemon=True).start()
|
||||
|
||||
# Send Readiness signal to EngineClient.
|
||||
ready_pipe.send({"status": "READY"})
|
||||
self.global_unfinished_reqs = False
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
|
||||
@staticmethod
|
||||
def run_engine_core(*args, **kwargs):
|
||||
def run_engine_core(*args,
|
||||
dp_rank: int = 0,
|
||||
local_dp_rank: int = 0,
|
||||
ready_pipe,
|
||||
**kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore):
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
parent_process = psutil.Process().parent()
|
||||
engine_core = None
|
||||
engine_core: Optional[EngineCoreProc] = None
|
||||
try:
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
parallel_config: ParallelConfig = kwargs[
|
||||
"vllm_config"].parallel_config
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
# Set data parallel rank for this engine process.
|
||||
parallel_config.data_parallel_rank = dp_rank
|
||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||
engine_core = DPEngineCoreProc(*args, **kwargs)
|
||||
else:
|
||||
engine_core = EngineCoreProc(*args, **kwargs)
|
||||
|
||||
# Send Readiness signal to EngineClient.
|
||||
ready_pipe.send({"status": "READY"})
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore):
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
step_fn = (self.step
|
||||
if self.batch_queue is None else self.step_with_batch_queue)
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
while not self.scheduler.has_requests():
|
||||
logger.debug("EngineCore busy loop waiting.")
|
||||
req = self.input_queue.get()
|
||||
self._handle_client_request(*req)
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
self._process_engine_step()
|
||||
|
||||
# 2) Handle any new client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(*req)
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
# 3) Step the engine core.
|
||||
outputs = step_fn()
|
||||
waited = False
|
||||
while not self.global_unfinished_reqs and not (
|
||||
self.scheduler.has_requests()):
|
||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
req = self.input_queue.get()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
# 4) Put EngineCoreOutputs into the output queue.
|
||||
if outputs is not None:
|
||||
self.output_queue.put_nowait(outputs)
|
||||
if waited:
|
||||
logger.debug(
|
||||
"EngineCore loop active - local unfinished: %s, finished: %s.",
|
||||
self.scheduler.has_unfinished_requests(),
|
||||
self.scheduler.has_finished_requests())
|
||||
|
||||
# Handle any more client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
def _process_engine_step(self):
    """Run one engine step and enqueue its outputs, if any.

    Called only when there are unfinished local requests.
    """
    # Step the engine core.
    step_outputs = self.step_fn()
    if step_outputs is None:
        # Nothing was produced by this step; leave the output queue alone.
        return

    # Put EngineCoreOutputs into the output queue.
    self.output_queue.put_nowait(step_outputs)
|
||||
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore):
|
||||
self.add_request(request)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.START_DP:
|
||||
if not self.global_unfinished_reqs:
|
||||
logger.debug("EngineCore starting idle loop.")
|
||||
self.global_unfinished_reqs = True
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
call_id, method_name, args = request
|
||||
output = UtilityOutput(call_id)
|
||||
@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore):
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_socket(self, output_path: str):
|
||||
def process_output_socket(self, output_path: str, engine_index: int):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
# Msgpack serialization encoding.
|
||||
@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore):
|
||||
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
|
||||
while True:
|
||||
outputs = self.output_queue.get()
|
||||
outputs.engine_index = engine_index
|
||||
encoder.encode_into(outputs, buffer)
|
||||
socket.send_multipart((buffer, ), copy=False)
|
||||
socket.send(buffer, copy=False)
|
||||
|
||||
|
||||
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
    """ZMQ-wrapper for running EngineCore in background process
    in a data parallel context."""

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Prepare the process environment for this DP rank, then start
        the engine.

        The DP size and ranks are read from ``vllm_config.parallel_config``.
        The global DP rank is forwarded to the parent constructor as its
        final positional argument.
        """
        # Add process-specific prefix to stdout and stderr before
        # we initialize the engine.
        from multiprocessing import current_process
        proc_name = current_process().name
        proc_id = os.getpid()
        for stream in (sys.stdout, sys.stderr):
            _add_prefix(stream, proc_name, proc_id)

        parallel_config = vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        local_dp_rank = parallel_config.data_parallel_rank_local

        assert dp_size > 1
        assert 0 <= local_dp_rank <= dp_rank < dp_size

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            from vllm.platforms.cuda import device_id_to_physical_device_id

            # This rank gets a contiguous slice of tp_size logical devices,
            # each translated through device_id_to_physical_device_id().
            tp_size = parallel_config.tensor_parallel_size
            first_device = local_dp_rank * tp_size
            visible = (str(device_id_to_physical_device_id(i))
                       for i in range(first_device, first_device + tp_size))
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible)

        self.dp_group = parallel_config.stateless_init_dp_group()

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

    def shutdown(self):
        """Shut down the engine, then destroy the DP process group if one
        was created."""
        super().shutdown()
        # getattr guards against shutdown() running before __init__ has
        # assigned dp_group.
        dp_group = getattr(self, "dp_group", None)
        if dp_group:
            stateless_destroy_torch_distributed_process_group(dp_group)

    def run_busy_loop(self):
        """Core busy loop of the EngineCore for data parallel case."""

        # Loop until process is sent a SIGINT or SIGTERM
        while True:
            # 1) Poll the input queue until there is work to do.
            self._process_input_queue()

            has_local_work = self.scheduler.has_unfinished_requests()

            if not has_local_work:
                if self.scheduler.has_finished_requests():
                    # There are no unfinished requests, but there are some
                    # finished requests remaining to be removed from the
                    # batch state. This engine step won't perform a forward
                    # pass but will flush the finished requests to ensure
                    # up-to-date state is returned in the engine outputs.
                    self._process_engine_step()

                if not self.global_unfinished_reqs:
                    # All engines are idle.
                    continue

                # There must be unfinished requests in DP peers, run a
                # dummy forward pass.
                self.execute_dummy_batch()
            else:
                # 2) Step the engine core.
                self._process_engine_step()

                # Check if we have now finished all requests.
                has_local_work = self.scheduler.has_unfinished_requests()

            # 3) All-reduce operation to determine global unfinished reqs.
            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
                has_local_work)

            if not self.global_unfinished_reqs:
                # Notify client that we are pausing the loop.
                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
        """Report whether unfinished requests remain across the DP group.

        Returns True without consulting peers on 15 of every 16 calls; on
        the 16th the counter is reset and the answer comes from
        ``ParallelConfig.has_unfinished_dp``.
        """
        # Optimization - only perform finish-sync all-reduce every 16 steps.
        self.counter += 1
        if self.counter == 16:
            self.counter = 0
            return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                    local_unfinished)
        return True
|
||||
|
@ -8,10 +8,11 @@ import threading
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -60,6 +61,9 @@ class EngineCoreClient(ABC):
|
||||
"is not currently supported.")
|
||||
|
||||
if multiprocess_mode and asyncio_mode:
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
|
||||
|
||||
return AsyncMPClient(vllm_config, executor_class, log_stats)
|
||||
|
||||
if multiprocess_mode and not asyncio_mode:
|
||||
@ -207,28 +211,74 @@ class InprocClient(EngineCoreClient):
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
|
||||
class CoreEngine:
    """Client-side handle for one engine core (one per data parallel rank).

    Holds the ZMQ PUSH socket used to send input to the engine, the handle
    of the background process running it, and a count of the requests
    currently in flight to it.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        ctx: Union[zmq.Context, zmq.asyncio.Context],
        output_path: str,
        index: int = 0,
        local_dp_rank: int = 0,
    ):
        """Open the input socket and launch the engine core process.

        Args:
            vllm_config: Configuration forwarded to the engine process.
            executor_class: Executor type forwarded to the engine process.
            log_stats: Stats-logging flag forwarded to the engine process.
            ctx: ZMQ context used to create the input socket.
            output_path: Output path handed to the background process.
            index: Passed to the engine as ``dp_rank`` and used in the
                process name (``EngineCore_<index>``).
            local_dp_rank: Passed to the engine as ``local_dp_rank``.
        """
        # Paths and sockets for IPC.
        input_path = get_open_zmq_ipc_path()
        self.input_socket = make_zmq_socket(ctx, input_path,
                                            zmq.constants.PUSH)

        started = False
        try:
            engine_kwargs = {
                "vllm_config": vllm_config,
                "dp_rank": index,
                "local_dp_rank": local_dp_rank,
                "executor_class": executor_class,
                "log_stats": log_stats,
            }
            # Start EngineCore in background process.
            self.proc_handle = BackgroundProcHandle(
                input_path=input_path,
                output_path=output_path,
                process_name=f"EngineCore_{index}",
                target_fn=EngineCoreProc.run_engine_core,
                process_kwargs=engine_kwargs)

            self.num_reqs_in_flight = 0
            started = True
        finally:
            if not started:
                # Ensure socket is closed if process fails to start.
                self.close()

    def send_multipart(self, msg_parts: Sequence):
        """Send ``msg_parts`` on the input socket without copying.

        Returns whatever the socket's ``send_multipart`` returns.
        """
        return self.input_socket.send_multipart(msg_parts, copy=False)

    def close(self):
        """Shut down the engine process and close the input socket.

        Either attribute may be missing (``getattr`` with a default is
        used), in which case that step is skipped.
        """
        handle = getattr(self, "proc_handle", None)
        if handle:
            handle.shutdown()

        sock = getattr(self, "input_socket", None)
        if sock:
            sock.close(linger=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackgroundResources:
|
||||
"""Used as a finalizer for clean shutdown, avoiding
|
||||
circular reference back to the client object."""
|
||||
|
||||
ctx: zmq.Context
|
||||
ctx: Union[zmq.Context]
|
||||
core_engines: list[CoreEngine] = field(default_factory=list)
|
||||
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||
proc_handle: Optional[BackgroundProcHandle] = None
|
||||
shutdown_path: Optional[str] = None
|
||||
|
||||
def __call__(self):
|
||||
"""Clean up background resources."""
|
||||
|
||||
if self.proc_handle is not None:
|
||||
self.proc_handle.shutdown()
|
||||
for core_engine in self.core_engines:
|
||||
core_engine.close()
|
||||
|
||||
# ZMQ context termination can hang if the sockets
|
||||
# aren't explicitly closed first.
|
||||
if self.output_socket is not None:
|
||||
self.output_socket.close(linger=0)
|
||||
if self.input_socket is not None:
|
||||
self.input_socket.close(linger=0)
|
||||
if self.shutdown_path is not None:
|
||||
# We must ensure that the sync output socket is
|
||||
# closed cleanly in its own thread.
|
||||
@ -284,7 +334,7 @@ class MPClient(EngineCoreClient):
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# ZMQ setup.
|
||||
sync_ctx = zmq.Context()
|
||||
sync_ctx = zmq.Context(io_threads=2)
|
||||
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
|
||||
|
||||
# This will ensure resources created so far are closed
|
||||
@ -293,28 +343,38 @@ class MPClient(EngineCoreClient):
|
||||
self.resources = BackgroundResources(ctx=sync_ctx)
|
||||
self._finalizer = weakref.finalize(self, self.resources)
|
||||
|
||||
# Paths for IPC.
|
||||
# Paths and sockets for IPC.
|
||||
self.output_path = get_open_zmq_ipc_path()
|
||||
input_path = get_open_zmq_ipc_path()
|
||||
|
||||
# Start EngineCore in background process.
|
||||
self.resources.proc_handle = BackgroundProcHandle(
|
||||
input_path=input_path,
|
||||
output_path=self.output_path,
|
||||
process_name="EngineCore",
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
process_kwargs={
|
||||
"vllm_config": vllm_config,
|
||||
"executor_class": executor_class,
|
||||
"log_stats": log_stats,
|
||||
})
|
||||
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
|
||||
vllm_config, executor_class, log_stats, self.ctx, self.output_path,
|
||||
index, local_dp_rank)
|
||||
|
||||
# Start engine core process(es).
|
||||
self._init_core_engines(vllm_config, new_core_engine,
|
||||
self.resources.core_engines)
|
||||
|
||||
# Wait for engine core process(es) to start.
|
||||
for engine in self.resources.core_engines:
|
||||
engine.proc_handle.wait_for_startup()
|
||||
|
||||
# Create input socket.
|
||||
self.resources.input_socket = make_zmq_socket(self.ctx, input_path,
|
||||
zmq.constants.PUSH)
|
||||
self.input_socket = self.resources.input_socket
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
def _init_core_engines(
    self,
    vllm_config: VllmConfig,
    new_core_engine: Callable[[int, Optional[int]], CoreEngine],
    core_engines: list[CoreEngine],
) -> None:
    """Create the single core engine used in the default case.

    Args:
        vllm_config: Source of the data parallel rank settings.
        new_core_engine: Factory called as
            ``new_core_engine(dp_rank, local_dp_rank)``.
        core_engines: List that the created engine is appended to.
    """
    parallel_cfg = vllm_config.parallel_config
    rank = parallel_cfg.data_parallel_rank
    local_rank = parallel_cfg.data_parallel_rank_local
    if local_rank is None:
        # No local rank configured: fall back to the global DP rank.
        local_rank = rank

    engine = new_core_engine(rank, local_rank)
    core_engines.append(engine)
    self.core_engine = engine
|
||||
|
||||
def shutdown(self):
    """Release the client's background resources by running its finalizer."""
    run_finalizer = self._finalizer
    run_finalizer()
|
||||
|
||||
@ -370,7 +430,7 @@ class SyncMPClient(MPClient):
|
||||
# shutdown signal, exit thread.
|
||||
break
|
||||
|
||||
(frame, ) = out_socket.recv_multipart(copy=False)
|
||||
frame = out_socket.recv(copy=False)
|
||||
outputs = decoder.decode(frame.buffer)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
@ -391,18 +451,15 @@ class SyncMPClient(MPClient):
|
||||
def get_output(self) -> EngineCoreOutputs:
    """Return the next item obtained from ``self.outputs_queue.get()``."""
    pending = self.outputs_queue
    return pending.get()
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
||||
# (RequestType, SerializedRequest)
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
self.core_engine.send_multipart(msg)
|
||||
|
||||
def _call_utility(self, method: str, *args) -> Any:
|
||||
def call_utility(self, method: str, *args) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future: Future[Any] = Future()
|
||||
self.utility_results[call_id] = future
|
||||
|
||||
self._send_input(EngineCoreRequestType.UTILITY,
|
||||
(call_id, method, args))
|
||||
|
||||
@ -419,34 +476,34 @@ class SyncMPClient(MPClient):
|
||||
self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
def profile(self, is_start: bool = True) -> None:
|
||||
self._call_utility("profile", is_start)
|
||||
self.call_utility("profile", is_start)
|
||||
|
||||
def reset_prefix_cache(self) -> None:
|
||||
self._call_utility("reset_prefix_cache")
|
||||
self.call_utility("reset_prefix_cache")
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self._call_utility("add_lora", lora_request)
|
||||
return self.call_utility("add_lora", lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self._call_utility("remove_lora", lora_id)
|
||||
return self.call_utility("remove_lora", lora_id)
|
||||
|
||||
def list_loras(self) -> set[int]:
|
||||
return self._call_utility("list_loras")
|
||||
return self.call_utility("list_loras")
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self._call_utility("pin_lora", lora_id)
|
||||
return self.call_utility("pin_lora", lora_id)
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
self._call_utility("sleep", level)
|
||||
self.call_utility("sleep", level)
|
||||
|
||||
def wake_up(self) -> None:
|
||||
self._call_utility("wake_up")
|
||||
self.call_utility("wake_up")
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
return self._call_utility("is_sleeping")
|
||||
return self.call_utility("is_sleeping")
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self._call_utility("execute_dummy_batch")
|
||||
self.call_utility("execute_dummy_batch")
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
@ -464,13 +521,21 @@ class AsyncMPClient(MPClient):
|
||||
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
|
||||
self.queue_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def _start_output_queue_task(self):
|
||||
self.outputs_handler: Optional[Callable[
|
||||
[AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
|
||||
|
||||
def _ensure_output_queue_task(self):
|
||||
if self.outputs_queue is not None:
|
||||
return
|
||||
|
||||
# Perform IO in separate task to parallelize as much as possible.
|
||||
# Avoid task having direct reference back to the client.
|
||||
self.outputs_queue = asyncio.Queue()
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
output_handler = self.outputs_handler
|
||||
_self_ref = weakref.ref(self) if output_handler else None
|
||||
output_path = self.output_path
|
||||
output_socket = make_zmq_socket(self.ctx, output_path,
|
||||
zmq.constants.PULL)
|
||||
@ -483,34 +548,52 @@ class AsyncMPClient(MPClient):
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
else:
|
||||
continue
|
||||
|
||||
if output_handler is not None:
|
||||
assert _self_ref is not None
|
||||
_self = _self_ref()
|
||||
if not _self:
|
||||
# Client has been garbage collected, abort.
|
||||
return
|
||||
await output_handler(_self, outputs)
|
||||
|
||||
if outputs.outputs or outputs.scheduler_stats:
|
||||
outputs_queue.put_nowait(outputs)
|
||||
|
||||
self.queue_task = asyncio.create_task(process_outputs_socket(),
|
||||
name="EngineCoreOutputQueueTask")
|
||||
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
if self.outputs_queue is None:
|
||||
await self._start_output_queue_task()
|
||||
assert self.outputs_queue is not None
|
||||
self._ensure_output_queue_task()
|
||||
assert self.outputs_queue is not None
|
||||
return await self.outputs_queue.get()
|
||||
|
||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
await self.core_engine.send_multipart(
|
||||
(request_type.value, self.encoder.encode(request)))
|
||||
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
await self.input_socket.send_multipart(msg, copy=False)
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
if self.outputs_queue is None:
|
||||
await self._start_output_queue_task()
|
||||
async def call_utility_async(self, method: str, *args) -> Any:
    """Call utility ``method`` on this client's core engine.

    Args:
        method: Name of the utility method to invoke.
        *args: Positional arguments forwarded with the call.

    Returns:
        The awaited result of ``self._call_utility_async``.
    """
    target_engine = self.core_engine
    return await self._call_utility_async(method,
                                          *args,
                                          engine=target_engine)
|
||||
|
||||
async def _call_utility_async(self, method: str, *args) -> Any:
|
||||
async def _call_utility_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
engine: CoreEngine,
|
||||
) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
await self._send_input(EngineCoreRequestType.UTILITY,
|
||||
(call_id, method, args))
|
||||
|
||||
message = (EngineCoreRequestType.UTILITY.value,
|
||||
self.encoder.encode((call_id, method, args)))
|
||||
await engine.send_multipart(message)
|
||||
self._ensure_output_queue_task()
|
||||
return await future
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
@ -524,31 +607,146 @@ class AsyncMPClient(MPClient):
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
async def profile_async(self, is_start: bool = True) -> None:
|
||||
await self._call_utility_async("profile", is_start)
|
||||
await self.call_utility_async("profile", is_start)
|
||||
|
||||
async def reset_prefix_cache_async(self) -> None:
|
||||
await self._call_utility_async("reset_prefix_cache")
|
||||
await self.call_utility_async("reset_prefix_cache")
|
||||
|
||||
async def sleep_async(self, level: int = 1) -> None:
|
||||
await self._call_utility_async("sleep", level)
|
||||
await self.call_utility_async("sleep", level)
|
||||
|
||||
async def wake_up_async(self) -> None:
|
||||
await self._call_utility_async("wake_up")
|
||||
await self.call_utility_async("wake_up")
|
||||
|
||||
async def is_sleeping_async(self) -> bool:
|
||||
return await self._call_utility_async("is_sleeping")
|
||||
return await self.call_utility_async("is_sleeping")
|
||||
|
||||
async def execute_dummy_batch_async(self) -> None:
|
||||
await self._call_utility_async("execute_dummy_batch")
|
||||
await self.call_utility_async("execute_dummy_batch")
|
||||
|
||||
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
|
||||
return await self._call_utility_async("add_lora", lora_request)
|
||||
return await self.call_utility_async("add_lora", lora_request)
|
||||
|
||||
async def remove_lora_async(self, lora_id: int) -> bool:
|
||||
return await self._call_utility_async("remove_lora", lora_id)
|
||||
return await self.call_utility_async("remove_lora", lora_id)
|
||||
|
||||
async def list_loras_async(self) -> set[int]:
|
||||
return await self._call_utility_async("list_loras")
|
||||
return await self.call_utility_async("list_loras")
|
||||
|
||||
async def pin_lora_async(self, lora_id: int) -> bool:
|
||||
return await self._call_utility_async("pin_lora", lora_id)
|
||||
return await self.call_utility_async("pin_lora", lora_id)
|
||||
|
||||
|
||||
class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore.

    One ``CoreEngine`` is launched per data parallel rank. The client tracks
    which engine each in-flight request was sent to and how many engines are
    currently considered running, so that all engines can be (re)started
    together when new work arrives.
    """

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(vllm_config, executor_class, log_stats)

        assert len(self.core_engines) > 1

        # Control message used for triggering dp idle mode loop.
        self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                             self.encoder.encode(None))

        self.num_engines_running = 0
        # Maps request id -> engine the request was sent to.
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:
        """Launch a core engine for each data parallel rank."""
        num_ranks = vllm_config.parallel_config.data_parallel_size
        for rank in range(num_ranks):
            # Multi-node not yet supported so local_dp_rank == dp_rank.
            engine = new_core_engine(rank, rank)
            core_engines.append(engine)

        self.core_engines = core_engines

    async def call_utility_async(self, method: str, *args) -> Any:
        """Call utility ``method`` on every engine concurrently."""
        pending = [
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]
        results = await asyncio.gather(*pending)
        # Only the result from the first engine is returned.
        return results[0]

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        """Send ``request`` to the least-loaded engine.

        If fewer engines than the total are counted as running, the other
        engines are sent the DP start control message at the same time.
        """
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None

        add_msg = (EngineCoreRequestType.ADD.value,
                   self.encoder.encode(request))

        target = self.get_core_engine_for_request()
        self.reqs_in_flight[request.request_id] = target
        target.num_reqs_in_flight += 1

        if self.num_engines_running < len(self.core_engines):
            # Send request to chosen engine and dp start loop
            # control message to all other engines.
            self.num_engines_running += len(self.core_engines)
            await asyncio.gather(*(
                engine.send_multipart(
                    add_msg if engine is target else self.start_dp_msg)
                for engine in self.core_engines))
        else:
            await target.send_multipart(add_msg)

        self._ensure_output_queue_task()

    def get_core_engine_for_request(self) -> CoreEngine:
        """Return the engine with the fewest requests in flight."""
        return min(self.core_engines,
                   key=lambda candidate: candidate.num_reqs_in_flight)

    @staticmethod
    async def process_engine_outputs(self: "DPAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        """Update in-flight bookkeeping from one batch of engine outputs.

        Finished requests are removed from ``reqs_in_flight`` and their
        engine's counter is decremented. An ``engine_paused`` output
        decrements ``num_engines_running``; when that reaches zero while
        requests are still in flight, the idle engines are restarted.
        """
        in_flight = self.reqs_in_flight
        if in_flight:
            for req_id in outputs.finished_requests or ():
                owner = in_flight.pop(req_id, None)
                if owner:
                    owner.num_reqs_in_flight -= 1

        if not outputs.engine_paused:
            return

        assert self.num_engines_running >= 1
        self.num_engines_running -= 1
        if self.num_engines_running or not self.reqs_in_flight:
            return

        # If there are requests in flight here, they must have
        # been sent after the engines paused. We must make
        # sure to start the other engines:
        self.num_engines_running = len(self.core_engines)
        idle_engines = [
            engine for engine in self.core_engines
            if not engine.num_reqs_in_flight
        ]
        if idle_engines:
            await asyncio.gather(*(engine.send_multipart(self.start_dp_msg)
                                   for engine in idle_engines))

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        """Abort the given requests on whichever engines are running them.

        Request ids that are not in ``reqs_in_flight`` are ignored.
        """
        if not request_ids:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            owner = self.reqs_in_flight.get(request_ids[0])
            if owner:
                await self._abort_requests(request_ids, owner)
            return

        # Group the ids by owning engine so each engine gets one message.
        grouped: dict[CoreEngine, list[str]] = {}
        for req_id in request_ids:
            owner = self.reqs_in_flight.get(req_id)
            if owner:
                grouped.setdefault(owner, []).append(req_id)
        for owner, owned_ids in grouped.items():
            await self._abort_requests(owned_ids, owner)

    async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
        """Send an ABORT message for ``request_ids`` to ``engine``."""
        abort_msg = (EngineCoreRequestType.ABORT.value,
                     self.encoder.encode(request_ids))
        await engine.send_multipart(abort_msg)
|
||||
|
@ -8,6 +8,7 @@ from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
|
||||
@ -60,11 +61,13 @@ class LLMEngine:
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
# Important: initialize the DP group before initializing the engine_core.
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
|
||||
# In the decoupled engine case this is handled in EngineCoreProc.
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
|
||||
self.dp_group = parallel_config.stateless_init_dp_group()
|
||||
else:
|
||||
self.dp_group = None
|
||||
self.should_execute_dummy_batch = False
|
||||
if self.dp_enabled:
|
||||
self.dp_group = self.parallel_config.stateless_init_dp_group()
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
@ -148,7 +151,7 @@ class LLMEngine:
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
has_unfinished = self.output_processor.has_unfinished_requests()
|
||||
if not self.dp_enabled:
|
||||
if self.dp_group is None:
|
||||
return has_unfinished
|
||||
return self.has_unfinished_requests_dp(has_unfinished)
|
||||
|
||||
@ -280,3 +283,7 @@ class LLMEngine:
|
||||
def pin_lora(self, lora_id: int) -> bool:
    """Prevent an adapter from being evicted.

    Delegates to ``self.engine_core.pin_lora`` and returns its result.
    """
    core = self.engine_core
    return core.pin_lora(lora_id)
|
||||
|
||||
def __del__(self):
    """Destroy the data parallel process group, if this engine holds one."""
    # getattr with a default tolerates instances on which ``dp_group``
    # was never assigned.
    group = getattr(self, "dp_group", None)
    if group:
        stateless_destroy_torch_distributed_process_group(group)
|
||||
|
@ -235,7 +235,10 @@ class WorkerProc:
|
||||
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
||||
|
||||
# Send Readiness signal to EngineCore process.
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||
# Set linger here because we want to ensure the message has
|
||||
# been sent before the context is closed.
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
|
||||
linger=10000) as ready_socket:
|
||||
payload = pickle.dumps(worker_response_mq_handle,
|
||||
protocol=pickle.HIGHEST_PROTOCOL)
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
@ -270,11 +273,13 @@ class WorkerProc:
|
||||
proc = context.Process(target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
daemon=True)
|
||||
proc.start()
|
||||
|
||||
# Wait for startup
|
||||
worker_response_mq_handle = WorkerProc.wait_for_startup(
|
||||
proc, ready_path)
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
|
||||
proc.start()
|
||||
|
||||
# Wait for startup
|
||||
worker_response_mq_handle = WorkerProc.wait_for_startup(
|
||||
proc, ready_socket)
|
||||
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
worker_response_mq_handle, 0)
|
||||
@ -337,23 +342,22 @@ class WorkerProc:
|
||||
@staticmethod
|
||||
def wait_for_startup(
|
||||
proc: BaseProcess,
|
||||
ready_path: str,
|
||||
ready_socket: zmq.Socket,
|
||||
) -> Optional[Handle]:
|
||||
"""Wait until the Worker is ready."""
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:
|
||||
|
||||
# Wait for Worker to send READY.
|
||||
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
logger.debug("Waiting for WorkerProc to startup.")
|
||||
# Wait for Worker to send READY.
|
||||
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
logger.debug("Waiting for WorkerProc to startup.")
|
||||
|
||||
if not proc.is_alive():
|
||||
raise RuntimeError("WorkerProc failed to start.")
|
||||
if not proc.is_alive():
|
||||
raise RuntimeError("WorkerProc failed to start.")
|
||||
|
||||
message = socket.recv_string()
|
||||
assert message == WorkerProc.READY_STR
|
||||
handle_frame = socket.recv(copy=False)
|
||||
handle = pickle.loads(handle_frame.buffer)
|
||||
return handle
|
||||
message = ready_socket.recv_string()
|
||||
assert message == WorkerProc.READY_STR
|
||||
handle_frame = ready_socket.recv(copy=False)
|
||||
handle = pickle.loads(handle_frame.buffer)
|
||||
return handle
|
||||
|
||||
class ResponseStatus(Enum):
|
||||
SUCCESS = auto()
|
||||
|
@ -31,7 +31,8 @@ class StatLoggerBase(ABC):
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, engine_index: int = 0):
|
||||
self.engine_index = engine_index
|
||||
self._reset(time.monotonic())
|
||||
self.last_scheduler_stats = SchedulerStats()
|
||||
# Prefix cache metrics. This cannot be reset.
|
||||
@ -78,11 +79,13 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
# Format and print output.
|
||||
logger.info(
|
||||
"Engine %03d: "
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Waiting: %d reqs, "
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
self.engine_index,
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
scheduler_stats.num_running_reqs,
|
||||
@ -94,7 +97,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
self._unregister_vllm_metrics()
|
||||
|
||||
# Use this flag to hide metrics that were deprecated in
|
||||
@ -102,8 +105,11 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
self.show_hidden_metrics = \
|
||||
vllm_config.observability_config.show_hidden_metrics
|
||||
|
||||
labelnames = ["model_name"]
|
||||
labelvalues = [vllm_config.model_config.served_model_name]
|
||||
labelnames = ["model_name", "engine"]
|
||||
labelvalues = [
|
||||
vllm_config.model_config.served_model_name,
|
||||
str(engine_index)
|
||||
]
|
||||
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
|
@ -105,7 +105,7 @@ class BackgroundProcHandle:
|
||||
process_kwargs: dict[Any, Any],
|
||||
):
|
||||
context = get_mp_context()
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
self.reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
assert ("ready_pipe" not in process_kwargs
|
||||
and "input_path" not in process_kwargs
|
||||
@ -115,14 +115,17 @@ class BackgroundProcHandle:
|
||||
process_kwargs["output_path"] = output_path
|
||||
|
||||
# Run busy loop in background process.
|
||||
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
|
||||
self.proc = context.Process(target=target_fn,
|
||||
kwargs=process_kwargs,
|
||||
name=process_name)
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.proc,
|
||||
input_path, output_path)
|
||||
self.proc.start()
|
||||
|
||||
def wait_for_startup(self):
|
||||
# Wait for startup.
|
||||
if reader.recv()["status"] != "READY":
|
||||
raise RuntimeError(f"{process_name} initialization failed. "
|
||||
if self.reader.recv()["status"] != "READY":
|
||||
raise RuntimeError(f"{self.proc.name} initialization failed. "
|
||||
"See root cause above.")
|
||||
|
||||
def shutdown(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user