# vllm/vllm/executor/executor_base.py
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
from abc import ABC, abstractmethod
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
                    Union)

import torch.nn as nn
from typing_extensions import TypeVar

import vllm.platforms
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)


class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on one device,
or it can be a distributed executor
that can execute the model on multiple devices.
"""
uses_ray: bool # whether the executor uses Ray for orchestration.
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()

    @abstractmethod
    def _init_executor(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
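
        Example:
            A minimal usage sketch; assumes `executor` is an instance of a
            concrete subclass and that the string-named method exists on
            the worker implementation in use::

                # Invoke a worker method by name on every worker.
                sizes = executor.collective_rpc("get_cache_block_size_bytes")

                # Or ship a callable; it receives the worker as `self`.
                def describe(self: WorkerBase, prefix: str) -> str:
                    return f"{prefix}: {type(self).__name__}"

                labels = executor.collective_rpc(describe, args=("worker", ))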
"""
raise NotImplementedError

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        Normally, this should simply delegate to the underlying Worker. Some
        ExecutorBase may require modification of the result, e.g. to ensure
        the selected cache sizes are compatible with all workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        results = self.collective_rpc("determine_num_available_blocks")
        num_gpu_blocks = min(r[0] for r in results)
        num_cpu_blocks = min(r[1] for r in results)
        return num_gpu_blocks, num_cpu_blocks
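
    # Illustrative numbers: if two workers report (8192, 2048) and
    # (8000, 2048) available (GPU, CPU) blocks, the min-reduction above
    # yields (8000, 2048), a size every worker can actually allocate.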

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker."""
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info("# %s blocks: %d, # CPU blocks: %d",
                    vllm.platforms.current_platform.device_name,
                    num_gpu_blocks, num_cpu_blocks)
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           self.model_config.max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    self.model_config.max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self.collective_rpc("initialize_cache",
                            args=(num_gpu_blocks, num_cpu_blocks))
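
    # Worked example of the concurrency estimate above (illustrative
    # numbers): with num_gpu_blocks=4096 and block_size=16, the cache holds
    # 4096 * 16 = 65536 tokens; at max_model_len=8192 that allows
    # 65536 / 8192 = 8.00x concurrent full-length requests.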

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.
        """

        def rpc_func(worker: WorkerBase) -> _R:
            return func(worker.get_model())

        return self.collective_rpc(rpc_func)
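
    # A hedged usage sketch (assumes a concrete `executor`): count the
    # parameters of every worker's model replica.
    #
    #     param_counts = executor.apply_model(
    #         lambda model: sum(p.numel() for p in model.parameters()))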

    def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.collective_rpc("execute_model",
args=(execute_model_req, ))
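        # One result per worker; the first entry (the driver worker)
        # carries the step's output.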
return output[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop."""
        return

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("add_lora", args=(lora_request, )))

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("remove_lora", args=(lora_id, )))

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return all(self.collective_rpc("pin_lora", args=(lora_id, )))

    def list_loras(self) -> Set[int]:
        sets = self.collective_rpc("list_loras")
        for s in sets:
            assert s == sets[0], "All workers should have the same LoRAs."
        return sets[0]
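
    # A hedged LoRA lifecycle sketch (assumes a concrete `executor`; the
    # LoRARequest field names and the adapter path are illustrative):
    #
    #     req = LoRARequest(lora_name="demo", lora_int_id=1,
    #                       lora_path="/path/to/adapter")
    #     assert executor.add_lora(req)
    #     assert 1 in executor.list_loras()
    #     assert executor.remove_lora(1)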

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("add_prompt_adapter",
                                args=(prompt_adapter_request, )))

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("remove_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return all(
            self.collective_rpc("pin_prompt_adapter",
                                args=(prompt_adapter_id, )))

    def list_prompt_adapters(self) -> Set[int]:
        sets = self.collective_rpc("list_prompt_adapters")
        for s in sets:
            assert (s == sets[0]
                    ), "All workers should have the same prompt adapters."
        return sets[0]

    def start_profile(self) -> None:
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1):
if self.is_sleeping:
logger.warning("Executor is already sleeping.")
return
time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep)

    def wake_up(self, tags: Optional[list[str]] = None):
if not self.is_sleeping:
logger.warning("Executor is not sleeping.")
return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning("Tag %s is not in sleeping tags %s", tag,
self.sleeping_tags)
return
time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter()
logger.info("It took %.6f seconds to wake up tags %s.",
time_after_wakeup - time_before_wakeup,
tags if tags is not None else self.sleeping_tags)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
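
    # A hedged lifecycle sketch (assumes a concrete `executor`; the tag
    # names come from the set assigned in sleep() above):
    #
    #     executor.sleep(level=1)             # offload weights and KV cache
    #     executor.wake_up(tags=["weights"])  # restore weights only
    #     executor.wake_up()                  # restore everything left
    #     assert not executor.is_sleeping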

    def save_sharded_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
self.collective_rpc("save_sharded_state",
kwargs=dict(path=path,
pattern=pattern,
max_size=max_size))
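
    # A hedged usage sketch (the checkpoint path and size cap below are
    # illustrative, not defaults):
    #
    #     executor.save_sharded_state(path="/tmp/sharded-ckpt",
    #                                 max_size=5 * 1024**3)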

    @abstractmethod
    def check_health(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor."""
        return

    def __del__(self):
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Executes one model step on the given sequences."""
        output = await make_async(self.execute_model)(execute_model_req)
        return output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop."""
        return

    async def check_health_async(self) -> None:
        """Checks if the executor is healthy. If not, it should raise an
        exception."""
        self.check_health()


class DistributedExecutorBase(ExecutorBase):
    """Abstract superclass of distributed executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine
        # case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        super().__init__(*args, **kwargs)

    def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
# TODO: unify into collective_rpc
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_tensor_parallel_workers_only=True)
# Only the driver worker returns the sampling results.
driver_outputs = self._driver_execute_model(execute_model_req)
assert driver_outputs is not None
return driver_outputs

    def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
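
    # A hedged sketch of the intended call pattern (assumes a concrete
    # distributed executor): the first execute_model() call starts the
    # worker loops, later calls go through the driver only, and stopping
    # the loop lets one-off collective RPCs run again.
    #
    #     out_a = executor.execute_model(req_a)  # starts worker loops
    #     out_b = executor.execute_model(req_b)  # loops already running
    #     executor.stop_remote_worker_execution_loop()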

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
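        # NOTE: `timeout` is accepted for interface compatibility but is not
        # forwarded to _run_workers by this base implementation.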
return self._run_workers(method, *args, **(kwargs or {}))

    @abstractmethod
    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of
                futures rather than blocking on the results.

        # TODO: simplify and merge with collective_rpc
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        raise NotImplementedError

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise).
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run the execution loop on all workers. It guarantees that either
        all workers run the loop or none of them do. The loop can be stopped
        by `stop_remote_worker_execution_loop`.

        The API is idempotent (it guarantees that at most one loop runs at
        any moment)."""
        raise NotImplementedError
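

# A minimal single-process subclass sketch, assuming a worker object that
# exposes `init_device()` and `load_model()` as vLLM's WorkerBase
# implementations do; `create_worker` is a hypothetical factory, and real
# executors (e.g. the Ray-based one) layer process/actor management on top.
#
#     from functools import partial
#
#     class SingleProcExecutor(ExecutorBase):
#         uses_ray = False
#
#         def _init_executor(self) -> None:
#             self.driver_worker = create_worker(self.vllm_config)
#             self.driver_worker.init_device()
#             self.driver_worker.load_model()
#
#         def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
#             fn = (getattr(self.driver_worker, method)
#                   if isinstance(method, str)
#                   else partial(method, self.driver_worker))
#             return [fn(*args, **(kwargs or {}))]
#
#         def check_health(self) -> None:
#             return  # single process: if this code runs, we are healthy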