[OpenVINO] Enable GPU support for OpenVINO vLLM backend (#8192)
This commit is contained in:
parent
afb050b29d
commit
f58d4fccc9
@ -3,7 +3,7 @@
|
||||
Installation with OpenVINO
|
||||
==========================
|
||||
|
||||
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features:
|
||||
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support, as well as on both integrated and discrete Intel® GPUs (`the list of supported GPUs <https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu>`_). OpenVINO vLLM backend supports the following advanced vLLM features:
|
||||
|
||||
- Prefix caching (``--enable-prefix-caching``)
|
||||
- Chunked prefill (``--enable-chunked-prefill``)
|
||||
@ -59,28 +59,51 @@ Install from source
|
||||
|
||||
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
|
||||
|
||||
- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html <https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html>`_.
|
||||
|
||||
.. _openvino_backend_performance_tips:
|
||||
|
||||
Performance tips
|
||||
----------------
|
||||
|
||||
vLLM OpenVINO backend uses the following environment variables to control behavior:
|
||||
vLLM OpenVINO backend environment variables
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- ``VLLM_OPENVINO_DEVICE`` to specify which device utilize for the inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g, ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, CPU device is used by default.
|
||||
|
||||
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`
|
||||
|
||||
CPU performance tips
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
CPU uses the following environment variables to control behavior:
|
||||
|
||||
- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
|
||||
|
||||
- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform.
|
||||
|
||||
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`
|
||||
|
||||
To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``)
|
||||
|
||||
OpenVINO best known configuration is:
|
||||
OpenVINO best known configuration for CPU is:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
|
||||
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
|
||||
|
||||
GPU performance tips
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
GPU device implements the logic for automatic detection of available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking into account ``gpu_memory_utilization`` option). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB space for KV cache).
|
||||
|
||||
Currently, the best performance using GPU can be achieved with the default vLLM execution parameters for models with quantized weights (8 and 4-bit integer data types are supported) and `preemption-mode=swap`.
|
||||
|
||||
OpenVINO best known configuration for GPU is:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
|
||||
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
.. _openvino_backend_limitations:
|
||||
|
||||
Limitations
|
||||
|
@ -3,5 +3,6 @@
|
||||
|
||||
# OpenVINO dependencies
|
||||
torch >= 2.1.2
|
||||
openvino ~= 2024.3.0
|
||||
optimum-intel[openvino] >= 1.18.2
|
||||
openvino ~= 2024.4.0
|
||||
openvino-tokenizers[transformers] ~= 2024.4.0
|
||||
optimum-intel[openvino] >= 1.19.0
|
||||
|
@ -9,6 +9,31 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
|
||||
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
|
||||
src_offset: int, dst_offset: int) -> None:
|
||||
|
||||
def create_roi_tensor(
|
||||
tensor: ov.Tensor,
|
||||
block_number: int,
|
||||
) -> ov.Tensor:
|
||||
roi_begin = ov.runtime.Coordinate([0, 0, 0, 0])
|
||||
roi_end = ov.runtime.Coordinate(tensor.get_shape())
|
||||
|
||||
roi_begin[0] = block_number
|
||||
roi_end[0] = block_number + 1
|
||||
|
||||
if isinstance(tensor, ov.Tensor):
|
||||
return ov.Tensor(tensor, roi_begin, roi_end)
|
||||
else:
|
||||
return ov.RemoteTensor(tensor, roi_begin, roi_end)
|
||||
|
||||
src_roi_tensor = \
|
||||
create_roi_tensor(src_tensor, src_offset)
|
||||
dst_roi_tensor = \
|
||||
create_roi_tensor(dst_tensor, dst_offset)
|
||||
src_roi_tensor.copy_to(dst_roi_tensor)
|
||||
|
||||
|
||||
class OpenVINOAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
@ -44,13 +69,12 @@ class OpenVINOAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: ov.Tensor,
|
||||
dst_kv_cache: ov.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
src_tensor: ov.Tensor,
|
||||
dst_tensor: ov.Tensor,
|
||||
src_to_dists: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
# OpenVINO currently supports only CPU, which does not require
|
||||
# swap of KV cache blocks
|
||||
raise NotImplementedError
|
||||
for src, dst in src_to_dists:
|
||||
copy_cache_block(src_tensor, dst_tensor, src, dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
@ -59,8 +83,8 @@ class OpenVINOAttentionBackend(AttentionBackend):
|
||||
) -> None:
|
||||
for src, dst in src_to_dists:
|
||||
for key_cache, value_cache in kv_caches:
|
||||
key_cache.data[dst, :] = key_cache.data[src, :]
|
||||
value_cache.data[dst, :] = value_cache.data[src, :]
|
||||
copy_cache_block(key_cache, key_cache, src, dst)
|
||||
copy_cache_block(value_cache, value_cache, src, dst)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -35,6 +35,7 @@ if TYPE_CHECKING:
|
||||
VLLM_PP_LAYER_PARTITION: Optional[str] = None
|
||||
VLLM_CPU_KVCACHE_SPACE: int = 0
|
||||
VLLM_CPU_OMP_THREADS_BIND: str = ""
|
||||
VLLM_OPENVINO_DEVICE: str = "CPU"
|
||||
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
|
||||
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
|
||||
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
|
||||
@ -302,6 +303,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_CPU_OMP_THREADS_BIND":
|
||||
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
|
||||
|
||||
# OpenVINO device selection
|
||||
# default is CPU
|
||||
"VLLM_OPENVINO_DEVICE":
|
||||
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
|
||||
|
||||
# OpenVINO key-value cache space
|
||||
# default is 4GB
|
||||
"VLLM_OPENVINO_KVCACHE_SPACE":
|
||||
|
@ -17,6 +17,14 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_openvino_cpu() -> bool:
|
||||
return "CPU" in envs.VLLM_OPENVINO_DEVICE
|
||||
|
||||
|
||||
def is_openvino_gpu() -> bool:
|
||||
return "GPU" in envs.VLLM_OPENVINO_DEVICE
|
||||
|
||||
|
||||
class OpenVINOExecutor(ExecutorBase):
|
||||
|
||||
uses_ray: bool = False
|
||||
@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
def _init_executor(self) -> None:
|
||||
assert self.device_config.device_type == "openvino"
|
||||
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
|
||||
assert is_openvino_cpu() or is_openvino_gpu(), \
|
||||
"OpenVINO backend supports only CPU and GPU devices"
|
||||
|
||||
self.ov_core = ov.Core()
|
||||
self.model_config = _verify_and_get_model_config(self.model_config)
|
||||
self.cache_config = _verify_and_get_cache_config(self.cache_config)
|
||||
self.cache_config = _verify_and_get_cache_config(
|
||||
self.ov_core, self.cache_config)
|
||||
|
||||
# Instantiate the worker and load the model to CPU.
|
||||
self._init_worker()
|
||||
@ -40,6 +53,7 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = OpenVINOWorker(
|
||||
ov_core=self.ov_core,
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
@ -68,10 +82,13 @@ class OpenVINOExecutor(ExecutorBase):
|
||||
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||
# greater than one. We could log in the engine, but not all executors
|
||||
# have GPUs.
|
||||
# NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
|
||||
# referred as `gpu block`. Because we want to reuse the existing block
|
||||
# management procedure.
|
||||
logger.info("# CPU blocks: %d", num_gpu_blocks)
|
||||
# NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
|
||||
# is located on CPU memory but is referred as `gpu block`.
|
||||
# Because we want to reuse the existing block management procedure.
|
||||
device_blocks = num_gpu_blocks
|
||||
swap_blocks = num_cpu_blocks
|
||||
logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
|
||||
envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
|
||||
def _verify_and_get_cache_config(ov_core: ov.Core,
|
||||
config: CacheConfig) -> CacheConfig:
|
||||
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
|
||||
logger.info("KV cache type is overried to u8 via "
|
||||
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
|
||||
config.cache_dtype = ov.Type.u8
|
||||
if not is_openvino_cpu():
|
||||
logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
|
||||
"ignored for GPU, f16 data type will be used.")
|
||||
config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
logger.info("KV cache type is overridden to u8 via "
|
||||
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
|
||||
config.cache_dtype = ov.Type.u8
|
||||
else:
|
||||
core = ov.Core()
|
||||
inference_precision = core.get_property("CPU",
|
||||
hints.inference_precision)
|
||||
if inference_precision == ov.Type.bf16:
|
||||
config.cache_dtype = ov.Type.bf16
|
||||
if is_openvino_cpu():
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
inference_precision = ov_core.get_property(
|
||||
ov_device, hints.inference_precision)
|
||||
if inference_precision == ov.Type.bf16:
|
||||
config.cache_dtype = ov.Type.bf16
|
||||
else:
|
||||
config.cache_dtype = ov.Type.f16
|
||||
else:
|
||||
config.cache_dtype = ov.Type.f16
|
||||
|
||||
if config.block_size != 32:
|
||||
logger.info(
|
||||
f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
config.block_size = 32
|
||||
if is_openvino_cpu():
|
||||
if config.block_size != 32:
|
||||
logger.info(
|
||||
f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
config.block_size = 32
|
||||
else:
|
||||
if config.block_size != 16:
|
||||
logger.info(
|
||||
f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501
|
||||
)
|
||||
config.block_size = 16
|
||||
|
||||
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
|
||||
if kv_cache_space >= 0:
|
||||
if kv_cache_space == 0:
|
||||
if kv_cache_space == 0 and is_openvino_cpu():
|
||||
config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
|
||||
logger.warning(
|
||||
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
|
||||
|
@ -12,6 +12,7 @@ from torch import nn
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
||||
from vllm.config import DeviceConfig, ModelConfig
|
||||
from vllm.executor.openvino_executor import is_openvino_cpu
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
|
||||
_prune_hidden_states)
|
||||
@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
|
||||
shape = parameter.get_partial_shape()
|
||||
# use real block size if available, just a placeholder
|
||||
# to provide the expected rank
|
||||
x_size = 1
|
||||
num_blocks = ov.Dimension()
|
||||
block_size = ov.Dimension()
|
||||
head_size = ov.Dimension()
|
||||
# TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
|
||||
# pass more parameters to this function to set more static dimensions
|
||||
if input_name.startswith("key_cache."):
|
||||
cpu_shape = [num_blocks, shape[1], block_size, head_size]
|
||||
gpu_shape = [
|
||||
num_blocks,
|
||||
shape[1],
|
||||
shape[2].get_length() //
|
||||
x_size if shape[2].is_static else ov.Dimension(),
|
||||
block_size,
|
||||
x_size,
|
||||
]
|
||||
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
|
||||
elif input_name.startswith("value_cache."):
|
||||
cpu_shape = [num_blocks, shape[1], block_size, head_size]
|
||||
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
|
||||
gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
|
||||
else:
|
||||
continue
|
||||
parameter.set_partial_shape(
|
||||
@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ov_core: ov.Core,
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
kv_cache_dtype: ov.Type,
|
||||
@ -141,12 +133,12 @@ class OpenVINOCasualLM(nn.Module):
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
paged_attention_transformation(pt_model.model)
|
||||
_modify_cache_parameters(pt_model.model, kv_cache_dtype,
|
||||
device_config.device.type == "cpu")
|
||||
is_openvino_cpu())
|
||||
|
||||
core = ov.Core()
|
||||
ov_compiled = core.compile_model(pt_model.model, "CPU")
|
||||
ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
|
||||
self.ov_request = ov_compiled.create_infer_request()
|
||||
|
||||
def forward(
|
||||
@ -199,6 +191,7 @@ def get_model(
|
||||
**kwargs,
|
||||
) -> torch.nn.Module:
|
||||
lora_config = kwargs.get("lora_config", None)
|
||||
ov_core = kwargs.get("ov_core")
|
||||
if lora_config:
|
||||
raise ValueError(
|
||||
"OpenVINO modeling does not support LoRA, "
|
||||
@ -206,4 +199,5 @@ def get_model(
|
||||
"be added in the future. If this is important to you, "
|
||||
"please open an issue on github.")
|
||||
|
||||
return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)
|
||||
return OpenVINOCasualLM(ov_core, model_config, device_config,
|
||||
kv_cache_dtype)
|
||||
|
@ -42,6 +42,7 @@ class OpenVINOModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ov_core: ov.Core,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
@ -55,6 +56,7 @@ class OpenVINOModelRunner:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.ov_core = ov_core
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
@ -89,11 +91,10 @@ class OpenVINOModelRunner:
|
||||
self.model: nn.Module # Set after init_Model
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
device_config=self.device_config,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
device_config=self.device_config,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
ov_core=self.ov_core)
|
||||
|
||||
def _prepare_model_input(
|
||||
self,
|
||||
|
@ -5,6 +5,7 @@ import openvino as ov
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
@ -12,10 +13,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.executor.openvino_executor import is_openvino_cpu
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
|
||||
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
|
||||
@ -36,6 +41,8 @@ class OpenVINOCacheEngine:
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
device_config: DeviceConfig,
|
||||
ov_core: ov.Core,
|
||||
ov_device: str,
|
||||
) -> None:
|
||||
assert device_config.device_type == "openvino"
|
||||
self.cache_config = cache_config
|
||||
@ -56,9 +63,10 @@ class OpenVINOCacheEngine:
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
# Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
|
||||
# for OpenVINO backend, because we want to reuse KV cache management
|
||||
# in the scheduler.
|
||||
self.num_cpu_blocks = cache_config.num_gpu_blocks
|
||||
# for OpenVINO backend with a CPU target device, because we want
|
||||
# to reuse KV cache management in the scheduler.
|
||||
self.num_device_blocks = cache_config.num_gpu_blocks
|
||||
self.num_swap_blocks = cache_config.num_cpu_blocks
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(
|
||||
@ -74,34 +82,100 @@ class OpenVINOCacheEngine:
|
||||
# Initialize the cache.
|
||||
self.kv_cache: List[Tuple[ov.Tensor,
|
||||
ov.Tensor]] = self._allocate_kv_cache(
|
||||
self.num_cpu_blocks)
|
||||
self.num_device_blocks, ov_core,
|
||||
ov_device)
|
||||
|
||||
# Initialize the swap.
|
||||
self.swap_cache: List[Tuple[ov.Tensor,
|
||||
ov.Tensor]] = self._allocate_swap_cache(
|
||||
self.num_swap_blocks, ov_device)
|
||||
|
||||
def _allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
ov_core: ov.Core,
|
||||
ov_device: str,
|
||||
) -> List[Tuple[ov.Tensor, ov.Tensor]]:
|
||||
"""Allocates KV cache."""
|
||||
k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
|
||||
kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
|
||||
|
||||
if is_openvino_cpu():
|
||||
for _ in range(self.num_layers):
|
||||
key_blocks = ov.Tensor(self.cache_config.cache_dtype,
|
||||
k_block_shape)
|
||||
value_blocks = ov.Tensor(self.cache_config.cache_dtype,
|
||||
v_block_shape)
|
||||
kv_cache.append((key_blocks, value_blocks))
|
||||
else:
|
||||
# Update key_cache shape:
|
||||
k_block_shape = (v_block_shape[0], v_block_shape[1],
|
||||
v_block_shape[3], v_block_shape[2])
|
||||
|
||||
remote_context = ov_core.get_default_context(ov_device)
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
key_blocks = \
|
||||
remote_context.create_tensor(self.cache_config.cache_dtype,
|
||||
ov.Shape(k_block_shape),
|
||||
{})
|
||||
|
||||
value_blocks = \
|
||||
remote_context.create_tensor(self.cache_config.cache_dtype,
|
||||
ov.Shape(v_block_shape),
|
||||
{})
|
||||
|
||||
kv_cache.append((key_blocks, value_blocks))
|
||||
|
||||
return kv_cache
|
||||
|
||||
def _allocate_swap_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
ov_device: str,
|
||||
) -> List[Tuple[ov.Tensor, ov.Tensor]]:
|
||||
"""Allocates swap cache."""
|
||||
k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
|
||||
swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
|
||||
|
||||
if num_blocks == 0:
|
||||
return swap_cache
|
||||
|
||||
assert not is_openvino_cpu(), \
|
||||
"CPU device isn't supposed to have swap cache"
|
||||
|
||||
# Update key_cache shape:
|
||||
k_block_shape = (v_block_shape[0], v_block_shape[1], v_block_shape[3],
|
||||
v_block_shape[2])
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
key_blocks = ov.Tensor(self.cache_config.cache_dtype,
|
||||
k_block_shape)
|
||||
value_blocks = ov.Tensor(self.cache_config.cache_dtype,
|
||||
v_block_shape)
|
||||
kv_cache.append((key_blocks, value_blocks))
|
||||
return kv_cache
|
||||
swap_cache.append((key_blocks, value_blocks))
|
||||
|
||||
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
|
||||
raise NotImplementedError(
|
||||
"Swap is not supported in OpenVINOCacheEngine.")
|
||||
return swap_cache
|
||||
|
||||
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
|
||||
raise NotImplementedError(
|
||||
"Swap is not supported in OpenVINOCacheEngine.")
|
||||
def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
|
||||
for i in range(self.num_layers):
|
||||
for swap_tensor, kv_tensor in zip(self.swap_cache[i],
|
||||
self.kv_cache[i]):
|
||||
self.attn_backend.swap_blocks(swap_tensor, kv_tensor,
|
||||
src_to_dst)
|
||||
|
||||
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
|
||||
self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)
|
||||
def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
|
||||
for i in range(self.num_layers):
|
||||
for swap_tensor, kv_tensor in zip(self.swap_cache[i],
|
||||
self.kv_cache[i]):
|
||||
self.attn_backend.swap_blocks(kv_tensor, swap_tensor,
|
||||
src_to_dst)
|
||||
|
||||
def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
|
||||
if (len(src_to_dsts) > 0):
|
||||
self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
def get_cache_block_size(
|
||||
@ -139,6 +213,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ov_core: ov.Core,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
@ -153,6 +228,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
self.ov_core = ov_core
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.parallel_config.rank = rank
|
||||
@ -175,6 +251,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
init_cached_hf_modules()
|
||||
self.model_runner = OpenVINOModelRunner(
|
||||
self.ov_core,
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
@ -204,56 +281,69 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
This determines how many KV blocks can fit into the configured
|
||||
KV cache space.
|
||||
|
||||
Note that since vLLM assumes a block resides on GPU if it can be
|
||||
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
|
||||
This allows us to reuse the scheduler of vLLM without generalizing it
|
||||
to different devices.
|
||||
"""
|
||||
# For OpenVINO backend, the block number will be calculated based on the
|
||||
# openvino_kvcache_space_bytes.
|
||||
# For OpenVINO backend, in case of CPU device, the block number will be
|
||||
# calculated based on the openvino_kvcache_space_bytes.
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes //
|
||||
cache_block_size)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
num_gpu_blocks = num_cpu_blocks
|
||||
num_cpu_blocks = 0
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
if is_openvino_cpu():
|
||||
num_device_blocks = int(kvcache_space_bytes // cache_block_size)
|
||||
num_swap_blocks = 0
|
||||
else:
|
||||
if kvcache_space_bytes > 0:
|
||||
logger.info("KV_CACHE size was explicitly configured via "
|
||||
"VLLM_OPENVINO_KVCACHE_SPACE environment "
|
||||
"variable, ignoring profiling run.")
|
||||
kv_cache_size = kvcache_space_bytes
|
||||
else:
|
||||
try:
|
||||
kv_cache_size = self.profile_run()
|
||||
except Exception as err:
|
||||
raise RuntimeError(
|
||||
"The error occurred during profile run. This might be "
|
||||
"due to insufficient GPU memory. Consider decreasing "
|
||||
"`max_model_len` to limit the maximum simultaneously "
|
||||
"processed tokens.") from err
|
||||
|
||||
num_device_blocks = int(kv_cache_size // cache_block_size)
|
||||
num_swap_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
|
||||
return num_device_blocks, num_swap_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache. Currently, swappable CPU memory is not
|
||||
supported.
|
||||
"""Initialize the KV cache. Swappable CPU memory is only
|
||||
supported on GPU.
|
||||
|
||||
Since this worker does not support GPUs, we use the num_gpu_blocks to
|
||||
For CPU, we use the num_gpu_blocks to
|
||||
determine how many non-swappable CPU blocks to allocate.
|
||||
"""
|
||||
assert (num_cpu_blocks == 0
|
||||
), f"{type(self)} does not support swappable cache"
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
num_cpu_blocks = num_gpu_blocks
|
||||
num_device_blocks = num_gpu_blocks
|
||||
num_swap_blocks = num_cpu_blocks
|
||||
|
||||
self._validate_num_cpu_blocks(num_cpu_blocks)
|
||||
self.cache_config.num_gpu_blocks = num_cpu_blocks
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
if is_openvino_cpu():
|
||||
assert (num_swap_blocks == 0
|
||||
), f"{type(self)} does not support swappable cache for CPU"
|
||||
|
||||
self._validate_num_blocks(num_device_blocks)
|
||||
self.cache_config.num_gpu_blocks = num_device_blocks
|
||||
self.cache_config.num_cpu_blocks = num_swap_blocks
|
||||
|
||||
# Initialize the cache.
|
||||
self._init_cache_engine()
|
||||
|
||||
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
|
||||
"""Raise errors if the num_cpu_blocks is invalid."""
|
||||
if num_cpu_blocks <= 0:
|
||||
def _validate_num_blocks(self, num_blocks: int) -> None:
|
||||
"""Raise errors if the num_blocks is invalid."""
|
||||
if num_blocks <= 0:
|
||||
raise ValueError(
|
||||
"No available memory for the cache blocks. "
|
||||
"Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
|
||||
"initializing the engine.")
|
||||
|
||||
max_seq_len = self.cache_config.block_size * num_cpu_blocks
|
||||
max_seq_len = self.cache_config.block_size * num_blocks
|
||||
if self.model_config.max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||
@ -263,11 +353,14 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
"when initializing the engine.")
|
||||
|
||||
def _init_cache_engine(self) -> None:
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
self.cache_engine = OpenVINOCacheEngine(
|
||||
self.cache_config,
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.device_config,
|
||||
self.ov_core,
|
||||
ov_device,
|
||||
)
|
||||
self.kv_cache = self.cache_engine.kv_cache
|
||||
self.model_runner.block_size = self.cache_engine.block_size
|
||||
@ -275,9 +368,16 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
assert self.kv_cache is not None
|
||||
|
||||
# Populate the cache to warmup the memory
|
||||
for key_cache, value_cache in self.kv_cache:
|
||||
key_cache.data[:] = 0
|
||||
value_cache.data[:] = 0
|
||||
if is_openvino_cpu():
|
||||
for key_cache, value_cache in self.kv_cache:
|
||||
key_cache.data[:] = 0
|
||||
value_cache.data[:] = 0
|
||||
|
||||
def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
|
||||
self.cache_engine.swap_in(src_to_dst)
|
||||
|
||||
def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
|
||||
self.cache_engine.swap_out(src_to_dst)
|
||||
|
||||
def cache_copy(
|
||||
self,
|
||||
@ -300,17 +400,28 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
num_seq_groups: int = len(seq_group_metadata_list)
|
||||
assert execute_model_req is not None
|
||||
blocks_to_copy = execute_model_req.blocks_to_copy
|
||||
assert len(execute_model_req.blocks_to_swap_in) == 0
|
||||
assert len(execute_model_req.blocks_to_swap_out) == 0
|
||||
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
|
||||
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
|
||||
data: Dict[str, Any] = {
|
||||
"num_seq_groups": num_seq_groups,
|
||||
"blocks_to_copy": execute_model_req.blocks_to_copy,
|
||||
"blocks_to_swap_in": execute_model_req.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": execute_model_req.blocks_to_swap_out,
|
||||
}
|
||||
broadcast_tensor_dict(data, src=0)
|
||||
else:
|
||||
data = broadcast_tensor_dict(src=0)
|
||||
num_seq_groups = data["num_seq_groups"]
|
||||
blocks_to_copy = data["blocks_to_copy"]
|
||||
blocks_to_swap_in = data["blocks_to_swap_in"]
|
||||
blocks_to_swap_out = data["blocks_to_swap_out"]
|
||||
|
||||
if is_openvino_cpu():
|
||||
assert len(execute_model_req.blocks_to_swap_in) == 0
|
||||
assert len(execute_model_req.blocks_to_swap_out) == 0
|
||||
else:
|
||||
self.cache_swap_in(blocks_to_swap_in)
|
||||
self.cache_swap_out(blocks_to_swap_out)
|
||||
|
||||
self.cache_copy(blocks_to_copy)
|
||||
|
||||
@ -353,3 +464,149 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
)
|
||||
|
||||
def profile_run(self) -> int:
|
||||
ov_device = envs.VLLM_OPENVINO_DEVICE
|
||||
|
||||
assert not is_openvino_cpu(), \
|
||||
"CPU device isn't supposed to use profile run."
|
||||
|
||||
import openvino.properties.device as device
|
||||
import openvino.properties.intel_gpu as intel_gpu
|
||||
|
||||
ov_core = self.ov_core
|
||||
cache_config = self.cache_config
|
||||
model_config = self.model_config
|
||||
parallel_config = self.parallel_config
|
||||
device_config = self.device_config
|
||||
input_registry = INPUT_REGISTRY
|
||||
mm_registry = MULTIMODAL_REGISTRY
|
||||
mm_registry.init_mm_limits_per_prompt(model_config)
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
def model_profile_run():
|
||||
top_k = model_config.get_vocab_size() - 1
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=top_k)
|
||||
|
||||
max_num_batched_tokens = \
|
||||
self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
tmp_cache_config = CacheConfig(cache_config.block_size,
|
||||
cache_config.gpu_memory_utilization,
|
||||
cache_config.swap_space_bytes,
|
||||
"auto")
|
||||
tmp_cache_config.num_gpu_blocks = 1
|
||||
tmp_cache_config.num_cpu_blocks = 0
|
||||
tmp_cache_config.cache_dtype = cache_config.cache_dtype
|
||||
|
||||
profiling_cache_engine = OpenVINOCacheEngine(
|
||||
tmp_cache_config, model_config, parallel_config, device_config,
|
||||
ov_core, ov_device)
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the
|
||||
# total # number of tokens equal to max_num_batched_tokens.
|
||||
seqs: List[SequenceGroupMetadata] = []
|
||||
for group_id in range(max_num_seqs):
|
||||
seq_len = (max_num_batched_tokens // max_num_seqs +
|
||||
(group_id < max_num_batched_tokens % max_num_seqs))
|
||||
block_size = cache_config.block_size
|
||||
seq_num_blocks = (seq_len + block_size - 1) // block_size
|
||||
|
||||
seq_data, dummy_multi_modal_data = input_registry \
|
||||
.dummy_data_for_profiling(model_config,
|
||||
seq_len,
|
||||
mm_registry)
|
||||
|
||||
block_tables = [[0] * seq_num_blocks] * max_num_seqs
|
||||
seq = SequenceGroupMetadata(
|
||||
request_id=str(group_id),
|
||||
is_prompt=True,
|
||||
seq_data={group_id: seq_data},
|
||||
sampling_params=sampling_params,
|
||||
block_tables=block_tables,
|
||||
lora_request=None,
|
||||
multi_modal_data=dummy_multi_modal_data)
|
||||
seqs.append(seq)
|
||||
|
||||
self.model_runner.block_size = tmp_cache_config.block_size
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
self.model_runner.execute_model(seqs,
|
||||
profiling_cache_engine.kv_cache)
|
||||
|
||||
# explicitly delete temporary KV cache manager to free KV cache
|
||||
# when real inputs will be passed to OV
|
||||
del profiling_cache_engine
|
||||
|
||||
logger.info(
|
||||
"Start profiling run with dummy inputs to evaluate "
|
||||
"memory usage for %s. It might take a while.", ov_device)
|
||||
|
||||
model_profile_run()
|
||||
|
||||
gpu_device_type = ov_core.get_property(ov_device, device.type)
|
||||
memory_statistics = \
|
||||
ov_core.get_property(ov_device, intel_gpu.memory_statistics)
|
||||
memory_utilization = cache_config.gpu_memory_utilization
|
||||
|
||||
if gpu_device_type == device.Type.INTEGRATED and \
|
||||
memory_utilization >= 0.9:
|
||||
logger.warning(
|
||||
"iGPU is used with high gpu_memory_utilization=%f "
|
||||
"value. This may cause low performance due to "
|
||||
"occupying the majority of available system "
|
||||
"memory. Please consider decreasing "
|
||||
"gpu_memory_utilization or explicitly setting"
|
||||
"`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
|
||||
"variable.", memory_utilization)
|
||||
|
||||
# sum up all used device memory
|
||||
device_memory_types = ["cl_mem", "usm_device"]
|
||||
used_device_mem = \
|
||||
sum(memory_statistics.get(key, 0) for key in device_memory_types)
|
||||
|
||||
if gpu_device_type == device.Type.INTEGRATED:
|
||||
used_device_mem += memory_statistics.get("usm_host", 0)
|
||||
|
||||
# there could be unaccounted extra memory reserved by kernels, kept
|
||||
# in memory pools, etc
|
||||
# therefore, add a threshold to account for this
|
||||
used_memory_threshold = 1.1
|
||||
used_device_mem *= used_memory_threshold
|
||||
|
||||
total_device_memory = \
|
||||
ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)
|
||||
|
||||
def format_memory_size(size) -> str:
|
||||
units = ["B", "KB", "MB", "GB"]
|
||||
unit_index = 0
|
||||
|
||||
while size > 1024 and unit_index < len(units) - 1:
|
||||
size /= 1024
|
||||
unit_index += 1
|
||||
|
||||
return f"{size:.2f} {units[unit_index]}"
|
||||
|
||||
total_device_memory_str = \
|
||||
format(format_memory_size(total_device_memory))
|
||||
used_device_memory_str = \
|
||||
format(format_memory_size(used_device_mem))
|
||||
|
||||
logger.info(
|
||||
"Total %s memory: %s. "
|
||||
"Amount of memory required to run the model with "
|
||||
"max_num_batched_tokens=%d: %s.", ov_device,
|
||||
total_device_memory_str,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
used_device_memory_str)
|
||||
|
||||
if used_device_mem >= total_device_memory:
|
||||
raise RuntimeError(
|
||||
f"The required memory size {used_device_memory_str} for model "
|
||||
"is higher than the total available device "
|
||||
"memory {total_device_memory_str}. Please consider to "
|
||||
"decrease `max_num_batched_tokens` or increase "
|
||||
"`gpu_memory_utilization`")
|
||||
|
||||
return total_device_memory * memory_utilization - used_device_mem
|
||||
|
Loading…
x
Reference in New Issue
Block a user