# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
import queue
import signal
import threading
import uuid
import weakref
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union

import zmq
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
                        kill_process_tree, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.utils import BackgroundProcHandle

logger = init_logger(__name__)

AnyFuture = Union[asyncio.Future[Any], Future[Any]]

_R = TypeVar('_R')  # Return type for collective_rpc


class EngineCoreClient(ABC):
    """
    EngineCoreClient: subclasses handle different methods for pushing
    and pulling from the EngineCore for asyncio / multiprocessing.

    Subclasses:
    * InprocClient: In process EngineCore (for V0-style LLMEngine use)
    * SyncMPClient: ZMQ + background proc EngineCore (for LLM)
    * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
    """

    @staticmethod
    def make_client(
        multiprocess_mode: bool,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ) -> "EngineCoreClient":

        # TODO: support this for debugging purposes.
        if asyncio_mode and not multiprocess_mode:
            raise NotImplementedError(
                "Running EngineCore in asyncio without multiprocessing "
                "is not currently supported.")

        if multiprocess_mode and asyncio_mode:
            if vllm_config.parallel_config.data_parallel_size > 1:
                return DPAsyncMPClient(vllm_config, executor_class, log_stats)

            return AsyncMPClient(vllm_config, executor_class, log_stats)

        if multiprocess_mode and not asyncio_mode:
            return SyncMPClient(vllm_config, executor_class, log_stats)

        return InprocClient(vllm_config, executor_class, log_stats)
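
    # Usage sketch (hypothetical call site; the argument values below are
    # assumptions, not taken from this file):
    #
    #   client = EngineCoreClient.make_client(
    #       multiprocess_mode=True,
    #       asyncio_mode=False,
    #       vllm_config=vllm_config,        # a previously built VllmConfig
    #       executor_class=executor_class,  # a concrete Executor subclass
    #       log_stats=False,
    #   )
    #   # -> SyncMPClient, per the dispatch logic above.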

    @abstractmethod
    def shutdown(self):
        ...

    def get_output(self) -> EngineCoreOutputs:
        raise NotImplementedError

    def add_request(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    def reset_prefix_cache(self) -> None:
        raise NotImplementedError

    def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    def is_sleeping(self) -> bool:
        raise NotImplementedError

    def execute_dummy_batch(self) -> None:
        raise NotImplementedError

    async def execute_dummy_batch_async(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def list_loras(self) -> set[int]:
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError

    async def get_output_async(self) -> EngineCoreOutputs:
        raise NotImplementedError

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def reset_prefix_cache_async(self) -> None:
        raise NotImplementedError

    async def sleep_async(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    async def is_sleeping_async(self) -> bool:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    async def remove_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def list_loras_async(self) -> set[int]:
        raise NotImplementedError

    async def pin_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError


class InprocClient(EngineCoreClient):
    """
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step()
    EngineCore setup in this process (no busy loop).

    * pushes EngineCoreRequest directly into the EngineCore
    * pulls EngineCoreOutputs by stepping the EngineCore
    """

    def __init__(self, *args, **kwargs):
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> EngineCoreOutputs:
        return self.engine_core.step()

    def add_request(self, request: EngineCoreRequest) -> None:
        self.engine_core.add_request(request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)

    def shutdown(self) -> None:
        self.engine_core.shutdown()

    def profile(self, is_start: bool = True) -> None:
        self.engine_core.profile(is_start)

    def reset_prefix_cache(self) -> None:
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

    def execute_dummy_batch(self) -> None:
        self.engine_core.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)


class CoreEngine:
    """One per data parallel rank."""
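
    # Wraps one EngineCoreProc background process: holds the PUSH socket used
    # to send serialized requests to it, the BackgroundProcHandle used to
    # start and stop it, and a count of requests currently in flight on it.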
    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        ctx: Union[zmq.Context, zmq.asyncio.Context],
        output_path: str,
        index: int = 0,
        local_dp_rank: int = 0,
    ):
        # Paths and sockets for IPC.
        input_path = get_open_zmq_ipc_path()
        self.input_socket = make_zmq_socket(ctx, input_path,
                                            zmq.constants.PUSH)
        try:
            # Start EngineCore in background process.
            self.proc_handle = BackgroundProcHandle(
                input_path=input_path,
                output_path=output_path,
                process_name=f"EngineCore_{index}",
                target_fn=EngineCoreProc.run_engine_core,
                process_kwargs={
                    "vllm_config": vllm_config,
                    "dp_rank": index,
                    "local_dp_rank": local_dp_rank,
                    "executor_class": executor_class,
                    "log_stats": log_stats,
                })

            self.num_reqs_in_flight = 0
        finally:
            if not hasattr(self, "num_reqs_in_flight"):
                # Ensure socket is closed if process fails to start.
                self.close()

    def send_multipart(self, msg_parts: Sequence):
        return self.input_socket.send_multipart(msg_parts, copy=False)

    def close(self):
        if proc_handle := getattr(self, "proc_handle", None):
            proc_handle.shutdown()
        if socket := getattr(self, "input_socket", None):
            socket.close(linger=0)


@dataclass
class BackgroundResources:
    """Used as a finalizer for clean shutdown, avoiding
    circular reference back to the client object."""

    ctx: Union[zmq.Context]
    core_engines: list[CoreEngine] = field(default_factory=list)
    output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    shutdown_path: Optional[str] = None

    def __call__(self):
        """Clean up background resources."""

        for core_engine in self.core_engines:
            core_engine.close()

        # ZMQ context termination can hang if the sockets
        # aren't explicitly closed first.
        if self.output_socket is not None:
            self.output_socket.close(linger=0)
        if self.shutdown_path is not None:
            # We must ensure that the sync output socket is
            # closed cleanly in its own thread.
            with self.ctx.socket(zmq.PAIR) as shutdown_sender:
                shutdown_sender.connect(self.shutdown_path)
                # Send shutdown signal.
                shutdown_sender.send(b'')


class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
    EngineCore runs in a background process busy loop, getting
    new EngineCoreRequests and returning EngineCoreOutputs

    * pushes EngineCoreRequests via input_socket
    * pulls EngineCoreOutputs via output_socket

    * AsyncMPClient subclass for AsyncLLM usage
    * SyncMPClient subclass for LLM usage
    """

    def __init__(
        self,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        # The child processes will send SIGUSR1 when unrecoverable
        # errors happen. We kill the process tree here so that the
        # stack trace is very evident.
        # TODO(rob): rather than killing the main process, we should
        # figure out how to raise an AsyncEngineDeadError and
        # handle at the API server level so we can return a better
        # error code to the clients calling vLLM.
        def sigusr1_handler(signum, frame):
            logger.fatal("Got fatal signal from worker processes, shutting "
                         "down. See stack trace above for root cause issue.")
            kill_process_tree(os.getpid())

        if threading.current_thread() == threading.main_thread():
            signal.signal(signal.SIGUSR1, sigusr1_handler)
        else:
            logger.warning("SIGUSR1 handler not installed because we are not "
                           "running in the main thread. In this case the "
                           "forked engine process may not be killed when "
                           "an exception is raised, and you need to handle "
                           "the engine process shutdown manually.")

        # Serialization setup.
        self.encoder = MsgpackEncoder()
        self.decoder = MsgpackDecoder(EngineCoreOutputs)

        # ZMQ setup.
        sync_ctx = zmq.Context(io_threads=2)
        self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
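        # NOTE: in asyncio mode the asyncio context wraps the same underlying
        # sync context created above; BackgroundResources is handed the sync
        # context and uses it to open a plain PAIR socket during shutdown.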

        # This will ensure resources created so far are closed
        # when the client is garbage collected, even if an
        # exception is raised mid-construction.
        self.resources = BackgroundResources(ctx=sync_ctx)
        self._finalizer = weakref.finalize(self, self.resources)

        # Paths and sockets for IPC.
        self.output_path = get_open_zmq_ipc_path()

        new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
            vllm_config, executor_class, log_stats, self.ctx, self.output_path,
            index, local_dp_rank)

        # Start engine core process(es).
        self._init_core_engines(vllm_config, new_core_engine,
                                self.resources.core_engines)

        # Wait for engine core process(es) to start.
        for engine in self.resources.core_engines:
            engine.proc_handle.wait_for_startup()

        self.utility_results: dict[int, AnyFuture] = {}

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:

        # Default case - single core engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
        core_engine = new_core_engine(
            dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank)
        core_engines.append(core_engine)
        self.core_engine = core_engine

    def shutdown(self):
        self._finalizer()


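# Utility-call protocol: the client picks a 64-bit call_id, parks a Future in
# utility_results[call_id], and sends (call_id, method, args) as a UTILITY
# request; the engine replies with a UtilityOutput carrying the same call_id,
# which is resolved into that Future here.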
def _process_utility_output(output: UtilityOutput,
                            utility_results: dict[int, AnyFuture]):
    """Set the result from a utility method in the waiting future"""
    future = utility_results.pop(output.call_id)
    if output.failure_message is not None:
        future.set_exception(Exception(output.failure_message))
    else:
        future.set_result(output.result)


class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()

        # Ensure that the outputs socket processing thread does not have
        # a ref to the client which prevents gc.
        ctx = self.ctx
        output_path = self.output_path
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue

        shutdown_path = get_open_zmq_inproc_path()
        self.resources.shutdown_path = shutdown_path

        def process_outputs_socket():
            shutdown_socket = ctx.socket(zmq.PAIR)
            out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
            try:
                shutdown_socket.bind(shutdown_path)
                poller = zmq.Poller()
                poller.register(shutdown_socket)
                poller.register(out_socket)
                while True:
                    socks = poller.poll()
                    if not socks:
                        continue
                    if len(socks) == 2 or socks[0][0] == shutdown_socket:
                        # shutdown signal, exit thread.
                        break

                    frame = out_socket.recv(copy=False)
                    outputs = decoder.decode(frame.buffer)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                    else:
                        outputs_queue.put_nowait(outputs)
            finally:
                # Close sockets.
                shutdown_socket.close(linger=0)
                out_socket.close(linger=0)

        # Process outputs from engine in separate thread.
        self.output_queue_thread = Thread(target=process_outputs_socket,
                                          name="EngineCoreOutputQueueThread",
                                          daemon=True)
        self.output_queue_thread.start()

    def get_output(self) -> EngineCoreOutputs:
        return self.outputs_queue.get()
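
    # Typical driver loop (sketch only; the real calling code lives elsewhere
    # in vLLM):
    #   client.add_request(engine_core_request)
    #   outputs = client.get_output()  # blocks until the next batch arrives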

    def _send_input(self, request_type: EngineCoreRequestType, request: Any):
        # (RequestType, SerializedRequest)
        msg = (request_type.value, self.encoder.encode(request))
        self.core_engine.send_multipart(msg)
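
    # call_utility blocks the calling thread: it derives a 64-bit call id from
    # a UUID, registers a concurrent.futures.Future under that id, sends the
    # UTILITY request, and waits on future.result() until the output thread
    # above resolves it via _process_utility_output.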
    def call_utility(self, method: str, *args) -> Any:
        call_id = uuid.uuid1().int >> 64
        future: Future[Any] = Future()
        self.utility_results[call_id] = future
        self._send_input(EngineCoreRequestType.UTILITY,
                         (call_id, method, args))

        return future.result()

    def add_request(self, request: EngineCoreRequest) -> None:
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None
        self._send_input(EngineCoreRequestType.ADD, request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            self._send_input(EngineCoreRequestType.ABORT, request_ids)

    def profile(self, is_start: bool = True) -> None:
        self.call_utility("profile", is_start)

    def reset_prefix_cache(self) -> None:
        self.call_utility("reset_prefix_cache")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.call_utility("add_lora", lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.call_utility("remove_lora", lora_id)

    def list_loras(self) -> set[int]:
        return self.call_utility("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        return self.call_utility("pin_lora", lora_id)

    def sleep(self, level: int = 1) -> None:
        self.call_utility("sleep", level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.call_utility("wake_up", tags)

    def is_sleeping(self) -> bool:
        return self.call_utility("is_sleeping")

    def execute_dummy_batch(self) -> None:
        self.call_utility("execute_dummy_batch")

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.call_utility("collective_rpc", method, timeout, args,
                                 kwargs)


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
        self.queue_task: Optional[asyncio.Task] = None

        self.outputs_handler: Optional[Callable[
            [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
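
    # The output-reading task is created lazily, from within the first async
    # call, so that it is scheduled on the event loop actually running the
    # client.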
    def _ensure_output_queue_task(self):
        if self.outputs_queue is not None:
            return

        # Perform IO in separate task to parallelize as much as possible.
        # Avoid task having direct reference back to the client.
        self.outputs_queue = asyncio.Queue()
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
        output_handler = self.outputs_handler
        _self_ref = weakref.ref(self) if output_handler else None
        output_path = self.output_path
        output_socket = make_zmq_socket(self.ctx, output_path,
                                        zmq.constants.PULL)
        self.resources.output_socket = output_socket

        async def process_outputs_socket():
            while True:
                (frame, ) = await output_socket.recv_multipart(copy=False)
                outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
                if outputs.utility_output:
                    _process_utility_output(outputs.utility_output,
                                            utility_results)
                    continue

                if output_handler is not None:
                    assert _self_ref is not None
                    _self = _self_ref()
                    if not _self:
                        # Client has been garbage collected, abort.
                        return
                    await output_handler(_self, outputs)

                if outputs.outputs or outputs.scheduler_stats:
                    outputs_queue.put_nowait(outputs)

        self.queue_task = asyncio.create_task(process_outputs_socket(),
                                              name="EngineCoreOutputQueueTask")

    async def get_output_async(self) -> EngineCoreOutputs:
        self._ensure_output_queue_task()
        assert self.outputs_queue is not None
        return await self.outputs_queue.get()

    async def _send_input(self, request_type: EngineCoreRequestType,
                          request: Any) -> None:
        await self.core_engine.send_multipart(
            (request_type.value, self.encoder.encode(request)))

        self._ensure_output_queue_task()

    async def call_utility_async(self, method: str, *args) -> Any:
        return await self._call_utility_async(method,
                                              *args,
                                              engine=self.core_engine)

    async def _call_utility_async(
        self,
        method: str,
        *args,
        engine: CoreEngine,
    ) -> Any:
        call_id = uuid.uuid1().int >> 64
        future = asyncio.get_running_loop().create_future()
        self.utility_results[call_id] = future
        message = (EngineCoreRequestType.UTILITY.value,
                   self.encoder.encode((call_id, method, args)))
        await engine.send_multipart(message)
        self._ensure_output_queue_task()
        return await future

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None
        await self._send_input(EngineCoreRequestType.ADD, request)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            await self._send_input(EngineCoreRequestType.ABORT, request_ids)

    async def profile_async(self, is_start: bool = True) -> None:
        await self.call_utility_async("profile", is_start)

    async def reset_prefix_cache_async(self) -> None:
        await self.call_utility_async("reset_prefix_cache")

    async def sleep_async(self, level: int = 1) -> None:
        await self.call_utility_async("sleep", level)

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        await self.call_utility_async("wake_up", tags)

    async def is_sleeping_async(self) -> bool:
        return await self.call_utility_async("is_sleeping")

    async def execute_dummy_batch_async(self) -> None:
        await self.call_utility_async("execute_dummy_batch")

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        return await self.call_utility_async("add_lora", lora_request)

    async def remove_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("remove_lora", lora_id)

    async def list_loras_async(self) -> set[int]:
        return await self.call_utility_async("list_loras")

    async def pin_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("pin_lora", lora_id)

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return await self.call_utility_async("collective_rpc", method, timeout,
                                             args, kwargs)


class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(vllm_config, executor_class, log_stats)

        assert len(self.core_engines) > 1

        # Control message used for triggering dp idle mode loop.
        self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                             self.encoder.encode(None))

        self.num_engines_running = 0
        self.reqs_in_flight: dict[str, CoreEngine] = {}

        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]
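        # NOTE: stored as a plain function rather than a bound method; the
        # base class resolves the client through a weakref before each call,
        # so the handler does not keep this client alive.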

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:

        # Launch a core engine for each data parallel rank.
        dp_size = vllm_config.parallel_config.data_parallel_size
        for i in range(dp_size):
            # Multi-node not yet supported so local_dp_rank == dp_rank.
            core_engines.append(new_core_engine(i, i))

        self.core_engines = core_engines

    async def call_utility_async(self, method: str, *args) -> Any:
        # Only the result from the first engine is returned.
        return (await asyncio.gather(*[
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]))[0]

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        # NOTE: text prompt is not needed in the core engine as it has been
        # tokenized.
        request.prompt = None

        msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))

        chosen_engine = self.get_core_engine_for_request()
        self.reqs_in_flight[request.request_id] = chosen_engine
        chosen_engine.num_reqs_in_flight += 1
        if self.num_engines_running >= len(self.core_engines):
            await chosen_engine.send_multipart(msg)
        else:
            # Send request to chosen engine and dp start loop
            # control message to all other engines.
            self.num_engines_running += len(self.core_engines)
            await asyncio.gather(*[
                engine.send_multipart(msg if engine is
                                      chosen_engine else self.start_dp_msg)
                for engine in self.core_engines
            ])

        self._ensure_output_queue_task()
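
    # Load balancing: route each new request to the engine with the fewest
    # requests currently in flight (see CoreEngine.num_reqs_in_flight).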
    def get_core_engine_for_request(self) -> CoreEngine:
        return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)

    @staticmethod
    async def process_engine_outputs(self: "DPAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        if self.reqs_in_flight:
            for req_id in outputs.finished_requests or ():
                if engine := self.reqs_in_flight.pop(req_id, None):
                    engine.num_reqs_in_flight -= 1

        if outputs.engine_paused:
            assert self.num_engines_running >= 1
            self.num_engines_running -= 1
            if not self.num_engines_running and self.reqs_in_flight:
                # If there are requests in flight here, they must have
                # been sent after the engines paused. We must make
                # sure to start the other engines:
                self.num_engines_running = len(self.core_engines)
                coros = [
                    engine.send_multipart(self.start_dp_msg)
                    for engine in self.core_engines
                    if not engine.num_reqs_in_flight
                ]
                if coros:
                    await asyncio.gather(*coros)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if not request_ids:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            if engine := self.reqs_in_flight.get(request_ids[0]):
                await self._abort_requests(request_ids, engine)
            return

        by_engine: dict[CoreEngine, list[str]] = {}
        for req_id in request_ids:
            if engine := self.reqs_in_flight.get(req_id):
                by_engine.setdefault(engine, []).append(req_id)
        for engine, req_ids in by_engine.items():
            await self._abort_requests(req_ids, engine)

    async def _abort_requests(self, request_ids: list[str],
                              engine: CoreEngine) -> None:
        await engine.send_multipart((EngineCoreRequestType.ABORT.value,
                                     self.encoder.encode(request_ids)))