[OpenVINO] Enable GPU support for OpenVINO vLLM backend (#8192)

This commit is contained in:
Sergey Shlyapnikov 2024-10-03 01:50:01 +04:00 committed by GitHub
parent afb050b29d
commit f58d4fccc9
8 changed files with 446 additions and 107 deletions

View File

@ -3,7 +3,7 @@
Installation with OpenVINO
==========================
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features:
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support, as well as on both integrated and discrete Intel® GPUs (`the list of supported GPUs <https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu>`_). The OpenVINO vLLM backend supports the following advanced vLLM features:
- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)
@ -59,28 +59,51 @@ Install from source
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html <https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html>`_.
.. _openvino_backend_performance_tips:
Performance tips
----------------
vLLM OpenVINO backend uses the following environment variables to control behavior:
vLLM OpenVINO backend environment variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- ``VLLM_OPENVINO_DEVICE`` to specify which device to use for inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g., ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, the CPU device is used by default.
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weight compression during the model loading stage. By default, compression is turned off. You can also export the model with different compression techniques using ``optimum-cli`` and pass the exported folder as ``<model_id>``. A minimal usage sketch follows this list.
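As a rough illustration of how these variables combine with vLLM's Python API, the following hedged sketch (the model name, prompt, and device index are placeholders, not recommendations) selects a GPU device and enables weight compression before building an engine:

.. code-block:: python

   import os

   # Assumed values for illustration: pick the second GPU in the system and
   # compress weights to U8 while the model is being loaded.
   os.environ["VLLM_OPENVINO_DEVICE"] = "GPU.1"
   os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

   from vllm import LLM, SamplingParams

   llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
   outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
   print(outputs[0].outputs[0].text)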
CPU performance tips
~~~~~~~~~~~~~~~~~~~~
The CPU device uses the following environment variables to control its behavior:
- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger value allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.
- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control the KV cache precision. By default, FP16 / BF16 is used, depending on the platform.
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`
To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on our experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).
OpenVINO best known configuration is:
OpenVINO best known configuration for CPU is:
.. code-block:: console
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
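The same configuration can also be approximated from Python. This is only a sketch: the engine arguments mirror the CLI flags above, and the model name is a placeholder.

.. code-block:: python

   import os

   # Mirror the environment settings from the console example above.
   os.environ["VLLM_OPENVINO_KVCACHE_SPACE"] = "100"
   os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"
   os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

   from vllm import LLM

   # Chunked prefill with the recommended 256-token batches.
   llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
             enable_chunked_prefill=True,
             max_num_batched_tokens=256)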
GPU performance tips
~~~~~~~~~~~~~~~~~~~~
The GPU device implements automatic detection of the available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking the ``gpu_memory_utilization`` option into account). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using the ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB of space for the KV cache).
Currently, the best GPU performance is achieved with the default vLLM execution parameters for models with quantized weights (8-bit and 4-bit integer data types are supported) and ``preemption-mode=swap``.
OpenVINO best known configuration for GPU is:
.. code-block:: console
$ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json
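An equivalent hedged Python sketch for the GPU path (the ``gpu_memory_utilization`` value and the model name are assumptions; the ``preemption_mode`` engine argument mirrors the ``preemption-mode=swap`` recommendation above):

.. code-block:: python

   import os

   os.environ["VLLM_OPENVINO_DEVICE"] = "GPU"
   os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

   from vllm import LLM

   # Reserve most of the GPU memory for the KV cache and prefer swap-based
   # preemption, as recommended for the GPU backend.
   llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
             gpu_memory_utilization=0.8,
             preemption_mode="swap")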
.. _openvino_backend_limitations:
Limitations

View File

@ -3,5 +3,6 @@
# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0
optimum-intel[openvino] >= 1.18.2
openvino ~= 2024.4.0
openvino-tokenizers[transformers] ~= 2024.4.0
optimum-intel[openvino] >= 1.19.0

View File

@ -9,6 +9,31 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.backends.utils import CommonAttentionState
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:

    def create_roi_tensor(
        tensor: ov.Tensor,
        block_number: int,
    ) -> ov.Tensor:
        roi_begin = ov.runtime.Coordinate([0, 0, 0, 0])
        roi_end = ov.runtime.Coordinate(tensor.get_shape())
        roi_begin[0] = block_number
        roi_end[0] = block_number + 1

        if isinstance(tensor, ov.Tensor):
            return ov.Tensor(tensor, roi_begin, roi_end)
        else:
            return ov.RemoteTensor(tensor, roi_begin, roi_end)

    src_roi_tensor = \
        create_roi_tensor(src_tensor, src_offset)
    dst_roi_tensor = \
        create_roi_tensor(dst_tensor, dst_offset)
    src_roi_tensor.copy_to(dst_roi_tensor)
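# Hedged usage sketch (added for illustration, not part of the change itself):
# the tensor shape below is an assumption; it stands in for one layer's key
# cache, whose first dimension indexes blocks. The call copies block 3 into
# block 7 of the same tensor.
import numpy as np

key_cache_example = ov.Tensor(np.zeros((128, 8, 32, 64), dtype=np.float16))
copy_cache_block(key_cache_example, key_cache_example, src_offset=3, dst_offset=7)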
class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
@ -44,13 +69,12 @@ class OpenVINOAttentionBackend(AttentionBackend):
    @staticmethod
    def swap_blocks(
        src_kv_cache: ov.Tensor,
        dst_kv_cache: ov.Tensor,
        src_to_dst: torch.Tensor,
        src_tensor: ov.Tensor,
        dst_tensor: ov.Tensor,
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        # OpenVINO currently supports only CPU, which does not require
        # swap of KV cache blocks
        raise NotImplementedError
        for src, dst in src_to_dists:
            copy_cache_block(src_tensor, dst_tensor, src, dst)

    @staticmethod
    def copy_blocks(
@ -59,8 +83,8 @@ class OpenVINOAttentionBackend(AttentionBackend):
    ) -> None:
        for src, dst in src_to_dists:
            for key_cache, value_cache in kv_caches:
                key_cache.data[dst, :] = key_cache.data[src, :]
                value_cache.data[dst, :] = value_cache.data[src, :]
                copy_cache_block(key_cache, key_cache, src, dst)
                copy_cache_block(value_cache, value_cache, src, dst)
@dataclass

View File

@ -35,6 +35,7 @@ if TYPE_CHECKING:
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_OPENVINO_DEVICE: str = "CPU"
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
@ -302,6 +303,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_CPU_OMP_THREADS_BIND":
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
# OpenVINO device selection
# default is CPU
"VLLM_OPENVINO_DEVICE":
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":

View File

@ -17,6 +17,14 @@ from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
logger = init_logger(__name__)
def is_openvino_cpu() -> bool:
    return "CPU" in envs.VLLM_OPENVINO_DEVICE


def is_openvino_gpu() -> bool:
    return "GPU" in envs.VLLM_OPENVINO_DEVICE
class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False
@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase):
    def _init_executor(self) -> None:
        assert self.device_config.device_type == "openvino"
        assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
        assert is_openvino_cpu() or is_openvino_gpu(), \
            "OpenVINO backend supports only CPU and GPU devices"

        self.ov_core = ov.Core()
        self.model_config = _verify_and_get_model_config(self.model_config)
        self.cache_config = _verify_and_get_cache_config(self.cache_config)
        self.cache_config = _verify_and_get_cache_config(
            self.ov_core, self.cache_config)

        # Instantiate the worker and load the model to CPU.
        self._init_worker()
@ -40,6 +53,7 @@ class OpenVINOExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
@ -68,10 +82,13 @@ class OpenVINOExecutor(ExecutorBase):
        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
        # referred as `gpu block`. Because we want to reuse the existing block
        # management procedure.
        logger.info("# CPU blocks: %d", num_gpu_blocks)
        # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
        # is located on CPU memory but is referred as `gpu block`.
        # Because we want to reuse the existing block management procedure.
        device_blocks = num_gpu_blocks
        swap_blocks = num_cpu_blocks
        logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
                    envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    return config


def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
def _verify_and_get_cache_config(ov_core: ov.Core,
                                 config: CacheConfig) -> CacheConfig:
    if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
        logger.info("KV cache type is overried to u8 via "
        if not is_openvino_cpu():
            logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is "
                        "ignored for GPU, f16 data type will be used.")
            config.cache_dtype = ov.Type.f16
        else:
            logger.info("KV cache type is overridden to u8 via "
                        "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
            config.cache_dtype = ov.Type.u8
    else:
        core = ov.Core()
        inference_precision = core.get_property("CPU",
                                                hints.inference_precision)
        if is_openvino_cpu():
            ov_device = envs.VLLM_OPENVINO_DEVICE
            inference_precision = ov_core.get_property(
                ov_device, hints.inference_precision)
            if inference_precision == ov.Type.bf16:
                config.cache_dtype = ov.Type.bf16
            else:
                config.cache_dtype = ov.Type.f16
        else:
            config.cache_dtype = ov.Type.f16

    if is_openvino_cpu():
        if config.block_size != 32:
            logger.info(
                f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
                f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
            )
            config.block_size = 32
    else:
        if config.block_size != 16:
            logger.info(
                f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}"  # noqa: G004, E501
            )
            config.block_size = 16

    kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
    if kv_cache_space >= 0:
        if kv_cache_space == 0:
        if kv_cache_space == 0 and is_openvino_cpu():
            config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "

View File

@ -12,6 +12,7 @@ from torch import nn
import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig
from vllm.executor.openvino_executor import is_openvino_cpu
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
        shape = parameter.get_partial_shape()
        # use real block size if available, just a placeholder
        # to provide the expected rank
        x_size = 1
        num_blocks = ov.Dimension()
        block_size = ov.Dimension()
        head_size = ov.Dimension()
        # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
        # pass more parameters to this function to set more static dimensions
        if input_name.startswith("key_cache."):
            cpu_shape = [num_blocks, shape[1], block_size, head_size]
            gpu_shape = [
                num_blocks,
                shape[1],
                shape[2].get_length() //
                x_size if shape[2].is_static else ov.Dimension(),
                block_size,
                x_size,
            ]
            gpu_shape = [num_blocks, shape[1], shape[2], block_size]
        elif input_name.startswith("value_cache."):
            cpu_shape = [num_blocks, shape[1], block_size, head_size]
            gpu_shape = [num_blocks, shape[1], shape[2], block_size]
            gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
        else:
            continue
        parameter.set_partial_shape(
@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module):
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
@ -141,12 +133,12 @@ class OpenVINOCasualLM(nn.Module):
trust_remote_code=model_config.trust_remote_code,
)
ov_device = envs.VLLM_OPENVINO_DEVICE
paged_attention_transformation(pt_model.model)
_modify_cache_parameters(pt_model.model, kv_cache_dtype,
device_config.device.type == "cpu")
is_openvino_cpu())
core = ov.Core()
ov_compiled = core.compile_model(pt_model.model, "CPU")
ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
self.ov_request = ov_compiled.create_infer_request()
def forward(
@ -199,6 +191,7 @@ def get_model(
**kwargs,
) -> torch.nn.Module:
lora_config = kwargs.get("lora_config", None)
ov_core = kwargs.get("ov_core")
if lora_config:
raise ValueError(
"OpenVINO modeling does not support LoRA, "
@ -206,4 +199,5 @@ def get_model(
"be added in the future. If this is important to you, "
"please open an issue on github.")
return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)
return OpenVINOCasualLM(ov_core, model_config, device_config,
kv_cache_dtype)

View File

@ -42,6 +42,7 @@ class OpenVINOModelRunner:
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
@ -55,6 +56,7 @@ class OpenVINOModelRunner:
*args,
**kwargs,
):
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@ -89,11 +91,10 @@ class OpenVINOModelRunner:
self.model: nn.Module # Set after init_Model
def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
kv_cache_dtype=self.kv_cache_dtype,
)
ov_core=self.ov_core)
def _prepare_model_input(
self,

View File

@ -5,6 +5,7 @@ import openvino as ov
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
@ -12,10 +13,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.executor.openvino_executor import is_openvino_cpu
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@ -36,6 +41,8 @@ class OpenVINOCacheEngine:
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
ov_core: ov.Core,
ov_device: str,
) -> None:
assert device_config.device_type == "openvino"
self.cache_config = cache_config
@ -56,9 +63,10 @@ class OpenVINOCacheEngine:
        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
        # for OpenVINO backend, because we want to reuse KV cache management
        # in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks
        # for OpenVINO backend with a CPU target device, because we want
        # to reuse KV cache management in the scheduler.
        self.num_device_blocks = cache_config.num_gpu_blocks
        self.num_swap_blocks = cache_config.num_cpu_blocks

        # Get attention backend.
        self.attn_backend = get_attn_backend(
@ -74,33 +82,99 @@ class OpenVINOCacheEngine:
        # Initialize the cache.
        self.kv_cache: List[Tuple[ov.Tensor,
                                  ov.Tensor]] = self._allocate_kv_cache(
                                      self.num_cpu_blocks)
                                      self.num_device_blocks, ov_core,
                                      ov_device)

        # Initialize the swap.
        self.swap_cache: List[Tuple[ov.Tensor,
                                    ov.Tensor]] = self._allocate_swap_cache(
                                        self.num_swap_blocks, ov_device)
    def _allocate_kv_cache(
        self,
        num_blocks: int,
        ov_core: ov.Core,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates KV cache."""
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []

        if is_openvino_cpu():
            for _ in range(self.num_layers):
                key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                       k_block_shape)
                value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                         v_block_shape)
                kv_cache.append((key_blocks, value_blocks))
        else:
            # Update key_cache shape:
            k_block_shape = (v_block_shape[0], v_block_shape[1],
                             v_block_shape[3], v_block_shape[2])

            remote_context = ov_core.get_default_context(ov_device)

            for _ in range(self.num_layers):
                key_blocks = \
                    remote_context.create_tensor(self.cache_config.cache_dtype,
                                                 ov.Shape(k_block_shape),
                                                 {})
                value_blocks = \
                    remote_context.create_tensor(self.cache_config.cache_dtype,
                                                 ov.Shape(v_block_shape),
                                                 {})
                kv_cache.append((key_blocks, value_blocks))
        return kv_cache
    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

    def _allocate_swap_cache(
        self,
        num_blocks: int,
        ov_device: str,
    ) -> List[Tuple[ov.Tensor, ov.Tensor]]:
        """Allocates swap cache."""
        k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
        swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        raise NotImplementedError(
            "Swap is not supported in OpenVINOCacheEngine.")

        if num_blocks == 0:
            return swap_cache

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        assert not is_openvino_cpu(), \
            "CPU device isn't supposed to have swap cache"

        # Update key_cache shape:
        k_block_shape = (v_block_shape[0], v_block_shape[1], v_block_shape[3],
                         v_block_shape[2])

        for _ in range(self.num_layers):
            key_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                   k_block_shape)
            value_blocks = ov.Tensor(self.cache_config.cache_dtype,
                                     v_block_shape)
            swap_cache.append((key_blocks, value_blocks))

        return swap_cache

    def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(swap_tensor, kv_tensor,
                                              src_to_dst)

    def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        for i in range(self.num_layers):
            for swap_tensor, kv_tensor in zip(self.swap_cache[i],
                                              self.kv_cache[i]):
                self.attn_backend.swap_blocks(kv_tensor, swap_tensor,
                                              src_to_dst)

    def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None:
        if (len(src_to_dsts) > 0):
            self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)
@staticmethod
@ -139,6 +213,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
@ -153,6 +228,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
self.ov_core = ov_core
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
@ -175,6 +251,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
init_cached_hf_modules()
self.model_runner = OpenVINOModelRunner(
self.ov_core,
model_config,
parallel_config,
scheduler_config,
@ -204,56 +281,69 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
        This determines how many KV blocks can fit into the configured
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For OpenVINO backend, the block number will be calculated based on the
        # openvino_kvcache_space_bytes.
        # For OpenVINO backend, in case of CPU device, the block number will be
        # calculated based on the openvino_kvcache_space_bytes.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks
        if is_openvino_cpu():
            num_device_blocks = int(kvcache_space_bytes // cache_block_size)
            num_swap_blocks = 0
        else:
            if kvcache_space_bytes > 0:
                logger.info("KV_CACHE size was explicitly configured via "
                            "VLLM_OPENVINO_KVCACHE_SPACE environment "
                            "variable, ignoring profiling run.")
                kv_cache_size = kvcache_space_bytes
            else:
                try:
                    kv_cache_size = self.profile_run()
                except Exception as err:
                    raise RuntimeError(
                        "The error occurred during profile run. This might be "
                        "due to insufficient GPU memory. Consider decreasing "
                        "`max_model_len` to limit the maximum simultaneously "
                        "processed tokens.") from err

            num_device_blocks = int(kv_cache_size // cache_block_size)
            num_swap_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)

        return num_device_blocks, num_swap_blocks
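    # Hedged back-of-the-envelope example for the CPU branch above (all model
    # numbers are assumptions, added only for illustration): for a 7B-class
    # model with 32 layers, 32 KV heads, head size 128, CPU block size 32 and
    # u8 (1 byte) cache precision, one block holds key + value data for 32
    # tokens:
    #   cache_block_size = 32 * 32 * 128 * 32 * 1 * 2       # 8 MiB per block
    #   kvcache_space_bytes = 4 * 1024**3                   # default 4 GiB on CPU
    #   num_device_blocks = kvcache_space_bytes // cache_block_size  # 512 blocks
    # i.e. 512 * 32 = 16384 cacheable tokens with the default CPU settings.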
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.
        """Initialize the KV cache. Swappable CPU memory is only
        supported on GPU.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        For CPU, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks
        num_device_blocks = num_gpu_blocks
        num_swap_blocks = num_cpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0
        if is_openvino_cpu():
            assert (num_swap_blocks == 0
                    ), f"{type(self)} does not support swappable cache for CPU"

        self._validate_num_blocks(num_device_blocks)
        self.cache_config.num_gpu_blocks = num_device_blocks
        self.cache_config.num_cpu_blocks = num_swap_blocks

        # Initialize the cache.
        self._init_cache_engine()
    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid."""
        if num_cpu_blocks <= 0:
    def _validate_num_blocks(self, num_blocks: int) -> None:
        """Raise errors if the num_blocks is invalid."""
        if num_blocks <= 0:
            raise ValueError(
                "No available memory for the cache blocks. "
                "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
                "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        max_seq_len = self.cache_config.block_size * num_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
@ -263,11 +353,14 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
"when initializing the engine.")
    def _init_cache_engine(self) -> None:
        ov_device = envs.VLLM_OPENVINO_DEVICE
        self.cache_engine = OpenVINOCacheEngine(
            self.cache_config,
            self.model_config,
            self.parallel_config,
            self.device_config,
            self.ov_core,
            ov_device,
        )
        self.kv_cache = self.cache_engine.kv_cache
        self.model_runner.block_size = self.cache_engine.block_size
@ -275,10 +368,17 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
        assert self.kv_cache is not None

        # Populate the cache to warmup the memory
        if is_openvino_cpu():
            for key_cache, value_cache in self.kv_cache:
                key_cache.data[:] = 0
                value_cache.data[:] = 0

    def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None:
        self.cache_engine.swap_in(src_to_dst)

    def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None:
        self.cache_engine.swap_out(src_to_dst)
def cache_copy(
self,
blocks_to_copy: List[Tuple[int, int]],
@ -300,17 +400,28 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = execute_model_req.blocks_to_copy
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": execute_model_req.blocks_to_copy,
                "blocks_to_swap_in": execute_model_req.blocks_to_swap_in,
                "blocks_to_swap_out": execute_model_req.blocks_to_swap_out,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]

        if is_openvino_cpu():
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
        else:
            self.cache_swap_in(blocks_to_swap_in)
            self.cache_swap_out(blocks_to_swap_out)

        self.cache_copy(blocks_to_copy)
@ -353,3 +464,149 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
self.model_config,
self.parallel_config,
)
    def profile_run(self) -> int:
        ov_device = envs.VLLM_OPENVINO_DEVICE
        assert not is_openvino_cpu(), \
            "CPU device isn't supposed to use profile run."

        import openvino.properties.device as device
        import openvino.properties.intel_gpu as intel_gpu

        ov_core = self.ov_core
        cache_config = self.cache_config
        model_config = self.model_config
        parallel_config = self.parallel_config
        device_config = self.device_config

        input_registry = INPUT_REGISTRY
        mm_registry = MULTIMODAL_REGISTRY
        mm_registry.init_mm_limits_per_prompt(model_config)

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        def model_profile_run():
            top_k = model_config.get_vocab_size() - 1
            sampling_params = SamplingParams(top_p=0.99, top_k=top_k)

            max_num_batched_tokens = \
                self.scheduler_config.max_num_batched_tokens
            max_num_seqs = self.scheduler_config.max_num_seqs

            tmp_cache_config = CacheConfig(cache_config.block_size,
                                           cache_config.gpu_memory_utilization,
                                           cache_config.swap_space_bytes,
                                           "auto")
            tmp_cache_config.num_gpu_blocks = 1
            tmp_cache_config.num_cpu_blocks = 0
            tmp_cache_config.cache_dtype = cache_config.cache_dtype

            profiling_cache_engine = OpenVINOCacheEngine(
                tmp_cache_config, model_config, parallel_config, device_config,
                ov_core, ov_device)

            # Profile memory usage with max_num_sequences sequences and the
            # total number of tokens equal to max_num_batched_tokens.
            seqs: List[SequenceGroupMetadata] = []
            for group_id in range(max_num_seqs):
                seq_len = (max_num_batched_tokens // max_num_seqs +
                           (group_id < max_num_batched_tokens % max_num_seqs))

                block_size = cache_config.block_size
                seq_num_blocks = (seq_len + block_size - 1) // block_size

                seq_data, dummy_multi_modal_data = input_registry \
                    .dummy_data_for_profiling(model_config,
                                              seq_len,
                                              mm_registry)

                block_tables = [[0] * seq_num_blocks] * max_num_seqs
                seq = SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: seq_data},
                    sampling_params=sampling_params,
                    block_tables=block_tables,
                    lora_request=None,
                    multi_modal_data=dummy_multi_modal_data)
                seqs.append(seq)

            self.model_runner.block_size = tmp_cache_config.block_size

            # Run the model with the dummy inputs.
            self.model_runner.execute_model(seqs,
                                            profiling_cache_engine.kv_cache)

            # explicitly delete temporary KV cache manager to free KV cache
            # when real inputs will be passed to OV
            del profiling_cache_engine

        logger.info(
            "Start profiling run with dummy inputs to evaluate "
            "memory usage for %s. It might take a while.", ov_device)

        model_profile_run()
        gpu_device_type = ov_core.get_property(ov_device, device.type)
        memory_statistics = \
            ov_core.get_property(ov_device, intel_gpu.memory_statistics)
        memory_utilization = cache_config.gpu_memory_utilization

        if gpu_device_type == device.Type.INTEGRATED and \
                memory_utilization >= 0.9:
            logger.warning(
                "iGPU is used with high gpu_memory_utilization=%f "
                "value. This may cause low performance due to "
                "occupying the majority of available system "
                "memory. Please consider decreasing "
                "gpu_memory_utilization or explicitly setting "
                "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment "
                "variable.", memory_utilization)

        # sum up all used device memory
        device_memory_types = ["cl_mem", "usm_device"]
        used_device_mem = \
            sum(memory_statistics.get(key, 0) for key in device_memory_types)

        if gpu_device_type == device.Type.INTEGRATED:
            used_device_mem += memory_statistics.get("usm_host", 0)

        # there could be unaccounted extra memory reserved by kernels, kept
        # in memory pools, etc
        # therefore, add a threshold to account for this
        used_memory_threshold = 1.1
        used_device_mem *= used_memory_threshold

        total_device_memory = \
            ov_core.get_property(ov_device, intel_gpu.device_total_mem_size)

        def format_memory_size(size) -> str:
            units = ["B", "KB", "MB", "GB"]
            unit_index = 0

            while size > 1024 and unit_index < len(units) - 1:
                size /= 1024
                unit_index += 1

            return f"{size:.2f} {units[unit_index]}"

        total_device_memory_str = \
            format(format_memory_size(total_device_memory))
        used_device_memory_str = \
            format(format_memory_size(used_device_mem))

        logger.info(
            "Total %s memory: %s. "
            "Amount of memory required to run the model with "
            "max_num_batched_tokens=%d: %s.", ov_device,
            total_device_memory_str,
            self.scheduler_config.max_num_batched_tokens,
            used_device_memory_str)

        if used_device_mem >= total_device_memory:
            raise RuntimeError(
                f"The required memory size {used_device_memory_str} for model "
                "is higher than the total available device "
                f"memory {total_device_memory_str}. Please consider "
                "decreasing `max_num_batched_tokens` or increasing "
                "`gpu_memory_utilization`")

        return total_device_memory * memory_utilization - used_device_mem
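    # Hedged numerical sketch of the accounting above (device size, utilization
    # and profiled usage are assumed values, added only for illustration):
    #   total_device_memory = 16 GiB, gpu_memory_utilization = 0.9
    #   profiled usage = 6 GiB -> used_device_mem = 6 GiB * 1.1 = 6.6 GiB
    #   returned KV-cache budget = 16 GiB * 0.9 - 6.6 GiB = 7.8 GiB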