[Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU backend and update documentation (#6125)

Li, Jiang 2024-07-27 04:50:10 +08:00 committed by GitHub
parent aa4867791e
commit 3bbb4936dc
14 changed files with 403 additions and 89 deletions

View File

@ -3,26 +3,38 @@
set -ex
# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .
docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
# Setup cleanup
remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image
# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "cd tests;
docker exec cpu-test bash -c "
pip install pytest Pillow protobuf
cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name random \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"

View File

@ -2,8 +2,8 @@
FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
RUN apt-get update -y \
&& apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
@ -13,8 +13,9 @@ RUN pip install intel-openmp
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
RUN echo 'ulimit -c 0' >> ~/.bashrc
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
@ -25,7 +26,7 @@ COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
ARG VLLM_CPU_DISABLE_AVX512

View File

@ -83,6 +83,8 @@ endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS "numa")
#
# Define extension targets
@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")
@ -104,6 +107,7 @@ define_gpu_extension_target(
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI

View File

@ -4,6 +4,8 @@
#include <torch/library.h>
void init_cpu_threads_env(const std::string& cpu_ids);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
// CPU utils
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

csrc/cpu/utils.cpp (new file, 65 lines)
View File

@ -0,0 +1,65 @@
#include <numa.h>
#include <unistd.h>
#include <string>
#include <sched.h>
#include "cpu_types.hpp"
void init_cpu_threads_env(const std::string& cpu_ids) {
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
TORCH_CHECK(omp_cpu_mask->size > 0);
std::vector<int> omp_cpu_ids;
omp_cpu_ids.reserve(omp_cpu_mask->size);
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
int i = 0;
while (group_mask) {
if (group_mask & 1) {
omp_cpu_ids.emplace_back(offset + i);
}
++i;
group_mask >>= 1;
}
}
// Memory node binding
if (numa_available() != -1) {
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
bitmask* src_mask = numa_get_membind();
int pid = getpid();
// move all existing pages to the specified numa node.
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
int page_num = numa_migrate_pages(pid, src_mask, mask);
if (page_num == -1) {
TORCH_CHECK(false,
"numa_migrate_pages failed. errno: " + std::to_string(errno));
}
// restrict memory allocation node.
numa_set_membind(mask);
numa_set_strict(1);
}
// OMP threads binding
omp_set_num_threads((int)omp_cpu_ids.size());
torch::set_num_threads((int)omp_cpu_ids.size());
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
#pragma omp parallel for schedule(static, 1)
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
CPU_ZERO_S(size, mask);
CPU_SET_S(omp_cpu_ids[i], size, mask);
sched_setaffinity(0, sizeof(cpu_set_t), mask);
CPU_FREE(mask);
}
numa_free_nodemask(omp_cpu_mask);
}
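The function above does three things: it expands the core-list string with libnuma, migrates existing pages to (and restricts future allocations on) the NUMA node of the first core, and pins each OpenMP thread to exactly one core. As a rough, Linux-only Python illustration of the string expansion and the affinity primitive involved (the helper name parse_cpu_ids is hypothetical and not part of this commit):

# Illustrative sketch only (not part of this commit, Linux only): expand a
# binding string such as "0-3,8" into concrete core IDs and apply the same
# affinity primitive the C++ code uses for each OpenMP thread.
import os

def parse_cpu_ids(spec: str) -> list[int]:
    """Expand "0-3,8" -> [0, 1, 2, 3, 8]. Hypothetical helper for illustration."""
    ids: list[int] = []
    for part in spec.split(","):
        if "-" in part:
            lo, hi = part.split("-")
            ids.extend(range(int(lo), int(hi) + 1))
        else:
            ids.append(int(part))
    return ids

if __name__ == "__main__":
    cores = parse_cpu_ids("0-3,8")
    # Bind the current process; the C++ version additionally pins each OpenMP
    # thread to exactly one core and binds the NUMA memory node. Intersect with
    # the currently allowed cores so the sketch also runs on smaller machines.
    os.sched_setaffinity(0, set(cores) & os.sched_getaffinity(0))
    print("bound to cores:", sorted(os.sched_getaffinity(0)))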

View File

@ -10,6 +10,7 @@ Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Related runtime environment variables <env_intro>`
#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
#. :ref:`Performance tips <cpu_backend_performance_tips>`
@ -47,7 +48,7 @@ Build from source
.. code-block:: console
$ sudo apt-get update -y
$ sudo apt-get install -y gcc-12 g++-12
$ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
$ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
- Second, install Python packages for vLLM CPU backend building:
@ -71,6 +72,15 @@ Build from source
- If you want to force-enable AVX512_BF16 for cross-compilation, set the environment variable VLLM_CPU_AVX512BF16=1 before building.
.. _env_intro:
Related runtime environment variables
-------------------------------------
- ``VLLM_CPU_KVCACHE_SPACE``: specify the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user's memory management pattern.
- ``VLLM_CPU_OMP_THREADS_BIND``: specify the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means 32 OpenMP threads are bound to CPU cores 0-31. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there are 2 tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63. A short usage sketch follows this list.
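As an illustration of how the two variables combine with tensor parallelism from the offline API (a minimal sketch, not an official recipe; ``facebook/opt-125m`` is just a small placeholder model):

.. code-block:: python

    import os

    # 40 GB KV cache; rank 0 binds to cores 0-31, rank 1 to cores 32-63.
    os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"
    os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "0-31|32-63"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, max_tokens=32))
    print(outputs[0].outputs[0].text)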
.. _ipex_guidance:
Intel Extension for PyTorch
@ -78,15 +88,11 @@ Intel Extension for PyTorch
- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features and optimizations for an extra performance boost on Intel hardware.
- Starting from ``2.3.0``, IPEX is enabled in the CPU backend by default if it is installed.
.. _cpu_backend_performance_tips:
Performance tips
-----------------
- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
- We highly recommend using TCMalloc for high-performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:
.. code-block:: console
@ -96,11 +102,44 @@ Performance tips
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM
- vLLM CPU backend uses OpenMP for thread-parallel computation. For the best performance on CPU, it is critical to isolate the CPU cores used for OpenMP threads from other thread pools (such as a web-service event loop) to avoid CPU oversubscription.
- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserve CPU cores 30 and 31 for the framework and use CPU cores 0-29 for OpenMP:
- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable hyper-threading.
.. code-block:: console
- If using vLLM CPU backend on a multi-socket machine with NUMA, be sure to set CPU cores and memory nodes to avoid remote memory node access. ``numactl`` is a useful tool for CPU core and memory binding on NUMA platforms. In addition, the ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful.
$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-29
$ vllm serve facebook/opt-125m
- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread to each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading-enabled platform with 16 logical CPU cores / 8 physical CPU cores:
.. code-block:: console
$ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
# The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
# On this platform, it is recommended to bind OpenMP threads only to logical CPU cores 0-7 or 8-15
$ export VLLM_CPU_OMP_THREADS_BIND=0-7
$ python examples/offline_inference.py
- If using vLLM CPU backend on a multi-socket machine with NUMA, be sure to set CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross-NUMA-node memory access; a sketch for deriving such a binding string from the NUMA topology follows.
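One possible way to derive such a binding string is to read the per-node CPU lists the Linux kernel exposes under ``/sys``; the snippet below is only a sketch of that idea and is not part of vLLM:

.. code-block:: python

    # Sketch: derive VLLM_CPU_OMP_THREADS_BIND from the kernel's NUMA topology,
    # giving each tensor parallel rank the cores of one node (Linux only).
    import glob

    cpulists = []
    for path in sorted(glob.glob("/sys/devices/system/node/node*/cpulist")):
        with open(path) as f:
            cpulists.append(f.read().strip())

    # e.g. "0-31|32-63" on a two-socket machine; on hyper-threaded hosts you
    # would further drop sibling logical cores, as recommended above.
    binding = "|".join(cpulists)
    print("export VLLM_CPU_OMP_THREADS_BIND=" + binding)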

View File

@ -2,6 +2,6 @@
-r requirements-common.txt
# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu; platform_machine != "ppc64le"
torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
torch == 2.4.0; platform_machine != "ppc64le"
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

View File

@ -296,6 +296,9 @@ class GroupCoordinator:
pynccl_comm = self.pynccl_comm
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
elif input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
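With this change, CPU tensors are reduced through IPEX's shared-memory all-reduce rather than the default ``torch.distributed`` path. For reference only, the call shape of an in-place all-reduce can be sketched with the stock gloo backend (single-rank group purely for illustration; this is not the code path added here):

import torch
import torch.distributed as dist

# Single-rank "gloo" group just to show the in-place all_reduce call shape.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)
x = torch.ones(4)
dist.all_reduce(x)  # sums across ranks in place; a no-op with world_size=1
print(x)
dist.destroy_process_group()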

View File

@ -410,8 +410,6 @@ class AsyncLLMEngine:
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":

View File

@ -29,6 +29,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
@ -241,11 +242,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
# CPU key-value cache space
# (CPU backend only) CPU key-value cache space.
# default is 4GB
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":

View File

@ -1,16 +1,21 @@
from typing import List, Set, Tuple
import os
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.utils import (get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -22,46 +27,173 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"
#
# Environment variables for CPU executor
#
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Intel OpenMP setting
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_prealod_str:
# The time(milliseconds) that a thread should wait after
# completing the execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# Prevents the CPU from entering a low-performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine granularity parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
# Hint IPEX to use its shared-memory-based AllReduce
os.environ["LOCAL_WORLD_SIZE"] = str(
self.parallel_config.tensor_parallel_size)
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
self.scheduler_config = _verify_and_get_scheduler_config(
self.scheduler_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
ip = "127.0.0.1"
port = get_open_port()
self.distributed_init_method = get_distributed_init_method(ip, port)
def _init_worker(self):
from vllm.worker.cpu_worker import CPUWorker
is_async = isinstance(self, CPUExecutorAsync)
assert self.parallel_config.world_size == 1, (
"CPUExecutor only supports single CPU socket currently.")
world_size = self.parallel_config.tensor_parallel_size
result_handler = ResultHandler()
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
self.workers = []
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = CPUWorker(
if is_async:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(0, world_size)
]
self.driver_worker = self.workers[0]
self.workers = self.workers[1:]
self.driver_method_invoker = _async_driver_method_invoker
else:
self.driver_worker = self._create_worker()
self.driver_method_invoker = _driver_method_invoker
if world_size != 1:
self.workers = [
ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
)) for rank in range(1, world_size)
]
if world_size != 1 or is_async:
if is_async:
async_worker_list = self.workers + [self.driver_worker]
else:
async_worker_list = self.workers
self.worker_monitor = WorkerMonitor(async_worker_list,
result_handler)
result_handler.start()
self.worker_monitor.start()
self._run_workers("init_device")
self._run_workers("load_model")
def _create_worker(
self,
local_rank: int = 0,
rank: int = 0,
):
worker_module_name = "vllm.worker.cpu_worker"
worker_class_name = "CPUWorker"
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
assert self.distributed_init_method is not None
kwargs = dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True,
is_driver_worker=rank == 0,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
wrapper.init_worker(**kwargs)
return wrapper.worker
def _run_workers(
self,
method: str,
*args,
async_run_remote_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
Args:
async_run_remote_workers_only: If True the method will be run only
in the remote workers, not the driver worker. It will also be
run asynchronously and return a list of futures rather than
blocking on the results.
"""
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_remote_workers_only:
# Just return futures
return worker_outputs
driver_worker_output = self.driver_method_invoker(
self.driver_worker, method, *args, **kwargs)
# Get the results of the workers.
return [driver_worker_output
] + [output.get() for output in worker_outputs]
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
return self.driver_method_invoker(self.driver_worker,
"determine_num_available_blocks")
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
@ -74,43 +206,95 @@ class CPUExecutor(ExecutorBase):
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
if (self.parallel_config.tensor_parallel_size > 1
and self.parallel_worker_tasks is None):
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
)
output = self.driver_method_invoker(self.driver_worker,
"execute_model", execute_model_req)
return output
def stop_remote_worker_execution_loop(self) -> None:
if self.parallel_worker_tasks is None:
return
"""
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
self.driver_method_invoker(self.driver_worker, "execute_model", None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
# (this will raise otherwise)
self._wait_for_tasks_completion(parallel_worker_tasks)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
return all(self._run_workers("add_lora", lora_request))
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
return all(self._run_workers("remove_lora", lora_id))
def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)
assert lora_id > 0, "lora_id must be greater than 0."
return all(self._run_workers(
"pin_lora",
lora_id=lora_id,
))
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
return self.driver_method_invoker(self.driver_worker, "list_loras")
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
return all(
self._run_workers(
"add_prompt_adapter",
prompt_adapter_request,
))
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
return all(
self._run_workers(
"remove_prompt_adapter",
prompt_adapter_id,
))
def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()
return self.driver_method_invoker(self.driver_worker,
"list_prompt_adapters")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
return all(self._run_workers(
"pin_prompt_adapter",
prompt_adapter_id,
))
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
"""Raises an error if engine is unhealthy."""
if self.worker_monitor is not None and not self.worker_monitor.is_alive(
):
raise RuntimeError("Worker processes are not running")
def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
worker_monitor.close()
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
for result in parallel_worker_tasks:
result.get()
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
@ -118,14 +302,12 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
output = await make_async(self.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.
return
self.check_health()
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
@ -170,3 +352,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
f" {kv_cache_space}, expect a positive integer value.")
return config
def _driver_method_invoker(driver, method: str, *args, **kwargs):
return getattr(driver, method)(*args, **kwargs)
def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
return driver.execute_method(method, *args, **kwargs).get()
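These two invokers are the only place the executor distinguishes an in-process driver from a process-backed one: the synchronous path calls the worker method directly, while the asynchronous path goes through ``execute_method(...)`` and blocks on the returned future. A toy, self-contained sketch of that indirection (hypothetical stand-in classes, not vLLM code):

class _Future:
    def __init__(self, value):
        self._value = value

    def get(self):
        return self._value


class LocalDriver:
    """Stands in for a worker living in the executor process."""

    def determine_num_available_blocks(self):
        return (1024, 0)


class ProcessBackedDriver:
    """Stands in for ProcessWorkerWrapper: every call goes through execute_method."""

    def __init__(self, worker):
        self._worker = worker

    def execute_method(self, method, *args, **kwargs):
        return _Future(getattr(self._worker, method)(*args, **kwargs))


def _driver_method_invoker(driver, method, *args, **kwargs):
    return getattr(driver, method)(*args, **kwargs)


def _async_driver_method_invoker(driver, method, *args, **kwargs):
    return driver.execute_method(method, *args, **kwargs).get()


print(_driver_method_invoker(LocalDriver(), "determine_num_available_blocks"))
print(_async_driver_method_invoker(ProcessBackedDriver(LocalDriver()),
                                   "determine_num_available_blocks"))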

View File

@ -404,27 +404,6 @@ def update_environment_variables(envs: Dict[str, str]):
os.environ[k] = v
def init_kmp_env():
if not is_cpu():
return
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" not in ld_prealod_str:
return
# The time(milliseconds) that a thread should wait after completing the
# execution of a parallel region, before sleeping.
os.environ['KMP_BLOCKTIME'] = "1"
# dump settings on start up
os.environ['KMP_SETTINGS'] = "1"
# Prevents the CPU to run into low performance state
os.environ['KMP_TPAUSE'] = "0"
# Provides fine granularity parallelism
os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
def chunk_list(lst: List[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):

View File

@ -42,6 +42,7 @@ class CPUModelInput(ModelRunnerInputBase):
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
virtual_engine: Optional[int] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
@ -204,8 +205,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=None,
max_decode_seq_len=None,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
@ -345,7 +346,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
multi_modal_kwargs=multi_modal_kwargs,
)
@torch.inference_mode()
@torch.no_grad()
def execute_model(
self,
model_input: CPUModelInput,

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
@ -13,7 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, init_kmp_env
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
@ -152,13 +153,18 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
# try to initialize intel openmp optimized tunings
init_kmp_env()
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
if omp_cpuids == "all":
self.local_omp_cpuid = "all"
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
self.model_runner: CPUModelRunner = CPUModelRunner(
model_config,
parallel_config,
@ -177,6 +183,9 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)