2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2025-01-15 13:45:21 +08:00
|
|
|
import os
|
2024-12-10 01:24:46 +08:00
|
|
|
from typing import TYPE_CHECKING, Optional
|
2024-11-16 23:14:23 -08:00
|
|
|
|
2024-09-29 10:50:51 +08:00
|
|
|
import psutil
|
2024-09-12 00:46:46 +08:00
|
|
|
import torch
|
|
|
|
|
2024-11-16 23:14:23 -08:00
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
2024-11-19 11:22:26 +08:00
|
|
|
from .interface import Platform, PlatformEnum, _Backend
|
|
|
|
|
|
|
|
# Module-level logger shared by everything in this file.
logger = init_logger(__name__)

if TYPE_CHECKING:
    # Imported for static type checking only; not imported at runtime.
    from vllm.config import VllmConfig
else:
    # Runtime placeholder so that annotations naming VllmConfig still resolve.
    VllmConfig = None
|
|
|
|
|
2024-09-12 00:46:46 +08:00
|
|
|
|
|
|
|
class CpuPlatform(Platform):
    """``Platform`` subclass describing the CPU backend.

    Besides reporting static device properties, it rewrites an incoming
    ``VllmConfig`` in ``check_and_update_config`` and sets the process
    environment variables used by the CPU executor.
    """

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"cpu"``; ``device_id`` is ignored."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the dotted path of the attention backend class to use.

        The Torch SDPA backend is always returned. If a different backend
        was requested, that request is logged and otherwise ignored. The
        remaining parameters are not consulted.
        """
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        logger.info("Using Torch SDPA backend.")
        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory reported by ``psutil.virtual_memory()``.

        ``device_id`` is ignored.
        """
        return psutil.virtual_memory().total

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False`` regardless of ``enforce_eager``."""
        return False

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context manager."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for the CPU backend.

        The following changes are applied, in order:

        * force ``model_config.enforce_eager`` to ``True``;
        * default ``cache_config.block_size`` to 16 when it is ``None``;
        * set ``cache_config.cpu_kvcache_space_bytes`` from the
          ``VLLM_CPU_KVCACHE_SPACE`` setting (GiB), using 4 GiB when it is 0;
        * replace an fp16 model dtype with bf16 when chunked prefill or
          prefix caching is enabled;
        * fall back to the ``"mp"`` distributed executor backend;
        * resolve ``worker_cls == "auto"`` to the CPU worker (or to the
          speculative-decoding worker factory when a speculative config is
          present);
        * export the environment variables used by the CPU executor.

        Raises:
            RuntimeError: if ``VLLM_CPU_KVCACHE_SPACE`` is negative.
            AssertionError: if ``device_config.device_type`` is not ``"cpu"``.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        model_config = vllm_config.model_config
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # if the feature combo becomes valid.
        if not model_config.enforce_eager:
            logger.warning(
                "CUDA graph is not supported on CPU, fallback to the eager "
                "mode.")
            model_config.enforce_eager = True

        cache_config = vllm_config.cache_config

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

        # Reject invalid values first; 0 is valid and means "not set".
        if kv_cache_space < 0:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_cache_space}, expect a non-negative integer value.")

        if kv_cache_space == 0:
            cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
                "for CPU backend is not set, using 4 by default.")
        else:
            cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa

        scheduler_config = vllm_config.scheduler_config
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and model_config.dtype == torch.half):
            logger.warning("Chunked-prefill on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        if (parallel_config.distributed_executor_backend is not None
                and parallel_config.distributed_executor_backend != "mp"):
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           parallel_config.distributed_executor_backend)
            parallel_config.distributed_executor_backend = "mp"

        if parallel_config.worker_cls == "auto":
            if vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = \
                    "vllm.worker.cpu_worker.CPUWorker"
            else:
                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP settings, applied only when libiomp5 is preloaded.
        ld_preload_str = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload_str:
            # The time (milliseconds) that a thread should wait after
            # completing the execution of a parallel region, before sleeping.
            os.environ['KMP_BLOCKTIME'] = "1"
            # Prevents the CPU from entering a low-performance state
            os.environ['KMP_TPAUSE'] = "0"
            # Provides fine granularity parallelism
            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"

        # To hint IPEX uses shared memory based AllReduce
        os.environ["LOCAL_WORLD_SIZE"] = str(
            parallel_config.tensor_parallel_size)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning that pinning is unsupported and return ``False``."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the CPU Punica (LoRA) wrapper class."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
|