[Hardware][Intel] OpenVINO vLLM backend (#5379)

Ilya Lavrenov 2024-06-28 17:50:16 +04:00 committed by GitHub
parent 5932634409
commit 57f09a419c
22 changed files with 1393 additions and 23 deletions

.buildkite/run-openvino-test.sh Executable file

@ -0,0 +1,14 @@
# This script builds the OpenVINO docker image and runs offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex
# Try building the docker image
docker build -t openvino-test -f Dockerfile.openvino .
# Setup cleanup
remove_docker_container() { docker rm -f openvino-test || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py

Dockerfile.openvino Normal file

@ -0,0 +1,26 @@
# The vLLM Dockerfile is used to construct a vLLM image that can be directly used
# to run the OpenAI compatible server.
FROM ubuntu:22.04 AS dev
RUN apt-get update -y && \
apt-get install -y python3-pip git
WORKDIR /workspace
# copy requirements
COPY requirements-build.txt /workspace/vllm/
COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/
COPY vllm/ /workspace/vllm/vllm
COPY setup.py /workspace/vllm/
# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks
CMD ["/bin/bash"]


@ -207,9 +207,10 @@ if __name__ == '__main__':
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument('--block-size',
type=int,
default=16,


@ -349,9 +349,10 @@ if __name__ == "__main__":
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',


@ -0,0 +1,95 @@
.. _installation_openvino:
Installation with OpenVINO
==========================
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support. The OpenVINO vLLM backend supports the following advanced vLLM features:
- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)
**Table of contents**:
- :ref:`Requirements <openvino_backend_requirements>`
- :ref:`Quick start using Dockerfile <openvino_backend_quick_start_dockerfile>`
- :ref:`Build from source <install_openvino_backend_from_source>`
- :ref:`Performance tips <openvino_backend_performance_tips>`
- :ref:`Limitations <openvino_backend_limitations>`
.. _openvino_backend_requirements:
Requirements
------------
* OS: Linux
* Instruction set architecture (ISA) requirement: at least AVX2.
.. _openvino_backend_quick_start_dockerfile:
Quick start using Dockerfile
----------------------------
.. code-block:: console
$ docker build -f Dockerfile.openvino -t vllm-openvino-env .
$ docker run -it --rm vllm-openvino-env
.. _install_openvino_backend_from_source:
Install from source
-------------------
- First, install Python. For example, on Ubuntu 22.04, you can run:
.. code-block:: console
$ sudo apt-get update -y
$ sudo apt-get install python3
- Second, install the prerequisites for the vLLM OpenVINO backend installation:
.. code-block:: console
$ pip install --upgrade pip
$ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Finally, install vLLM with the OpenVINO backend:
.. code-block:: console
$ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
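- Optionally, a short offline-inference run can confirm that the installation works end to end (the model name and prompt below are only examples):
.. code-block:: python

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)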
.. _openvino_backend_performance_tips:
Performance tips
----------------
vLLM OpenVINO backend uses the following environment variables to control behavior:
- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. Set this parameter based on your hardware configuration and memory management pattern.
- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on the platform.
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weight compression during the model loading stage. By default, compression is turned off.
To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on our experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).
The best-known OpenVINO configuration is:
.. code-block:: console
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
.. _openvino_backend_limitations:
Limitations
-----------
- LoRA serving is not supported.
- Only LLM models are currently supported. LLaVA and encoder-decoder models are not currently enabled in the vLLM OpenVINO integration.
- Tensor and pipeline parallelism are not currently enabled in the vLLM OpenVINO integration.
- Speculative sampling is not tested within the vLLM OpenVINO integration.


@ -63,6 +63,7 @@ Documentation
getting_started/installation
getting_started/amd-installation
getting_started/openvino-installation
getting_started/cpu-installation
getting_started/neuron-installation
getting_started/tpu-installation


@ -0,0 +1,9 @@
# Common dependencies
-r requirements-common.txt
# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0.dev
optimum-intel[openvino] >= 1.17.2
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.


@ -233,6 +233,10 @@ def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"
def _is_openvino() -> bool:
return VLLM_TARGET_DEVICE == "openvino"
def _is_xpu() -> bool:
return VLLM_TARGET_DEVICE == "xpu"
@ -337,6 +341,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_openvino():
version += "+openvino"
elif _is_tpu():
version += "+tpu"
elif _is_cpu():
@ -388,6 +394,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_openvino():
requirements = _read_requirements("requirements-openvino.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu():
@ -396,7 +404,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-xpu.txt")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
"Unsupported platform, please use CUDA, ROCm, Neuron, "
"OpenVINO, or CPU.")
return requirements


@ -9,8 +9,8 @@ from vllm.attention.selector import which_attn_to_use
@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)


@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import List, Tuple
import openvino as ov
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "openvino"
@staticmethod
def get_impl_cls():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise NotImplementedError
@staticmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, num_kv_heads, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
src_to_dists: List[Tuple[int, int]],
) -> None:
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
@dataclass
class OpenVINOAttentionMetadata:
"""Metadata for OpenVINOAttentionBackend.
Basic terms used below:
- batch_size_in_sequences - total number of sequences to execute
- prompt_lens - number of scheduled tokens per sequence
- batch_size_in_tokens = sum(prompt_lens)
- max_context_len = max(context_lens)
- max_num_blocks = div_up(max_context_len, BLOCK_SIZE)
- num_blocks - total number of blocks in block_indices
"""
# Describes past KV cache size for each sequence within a batch
# Shape: [batch_size_in_sequences]
# Type: i32
past_lens: torch.Tensor
# Describes start indices of input / speculative tokens from the
# current sequences within a batch
# Shape: [batch_size_in_sequences + 1]
# Type: i32
subsequence_begins: torch.Tensor
# Describes block tables for each sequence within a batch -
# indices along 0th dimension in key_cache and value_cache inputs
# Shape: [num_blocks]
# Type: i32
block_indices: torch.Tensor
# Describes block tables for each sequence within a batch -
# for i-th element, it is an index in block_indices with the
# first block belonging to i-th sequence
# Shape: [batch_size_in_sequences + 1]
# Type: i32
block_indices_begins: torch.Tensor
# Describes max context length
# Shape: scalar
# Type: i32
max_context_len: torch.Tensor
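To make the field layout above concrete, here is a small illustrative sketch (not part of the diff; the batch composition, block ids, and lengths are made up) showing what the metadata tensors could look like for a toy batch with one prefill sequence and one decode sequence:
import torch

# Toy batch: sequence A is a prefill with 5 prompt tokens and no past,
# sequence B is a decode step with 8 already-computed tokens and 1 new token.
past_lens = torch.tensor([0, 8], dtype=torch.int32)
# Prefix sums of the per-sequence query lengths [5, 1].
subsequence_begins = torch.tensor([0, 5, 6], dtype=torch.int32)
# Sequence A occupies KV block 0, sequence B occupies KV block 3
# (the block ids are arbitrary in this example).
block_indices = torch.tensor([0, 3], dtype=torch.int32)
# Prefix sums of the per-sequence block counts [1, 1].
block_indices_begins = torch.tensor([0, 1, 2], dtype=torch.int32)
# max(context_lens) = max(5, 8 + 1) = 9, stored as an i32 scalar.
max_context_len = torch.tensor(9, dtype=torch.int32)
# These tensors would populate OpenVINOAttentionMetadata(past_lens,
# subsequence_begins, block_indices, block_indices_begins, max_context_len).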


@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__)
@ -17,6 +17,7 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
@ -64,6 +65,10 @@ def get_attn_backend(
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.IPEX:
assert is_xpu(), RuntimeError(
"IPEX attention backend is only used for the XPU device.")
@ -113,6 +118,11 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
if is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO
if is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)


@ -14,8 +14,8 @@ from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_tpu, is_xpu, print_warning_once,
update_environment_variables)
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once, update_environment_variables)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -781,6 +781,8 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_openvino():
self.device_type = "openvino"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu():
@ -796,7 +798,7 @@ class DeviceConfig:
self.device_type = device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
if self.device_type in ["neuron", "openvino"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None


@ -504,12 +504,14 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument(
"--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu", "tpu", "xpu"],
help='Device type for vLLM execution.')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=[
"auto", "cuda", "neuron", "cpu", "openvino",
"tpu", "xpu"
],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser = EngineArgs.add_cli_args_for_vlm(parser)


@ -393,6 +393,12 @@ class AsyncLLMEngine:
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync


@ -363,6 +363,9 @@ class LLMEngine:
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)


@ -28,6 +28,9 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_KVCACHE_SPACE: int = 0
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
@ -49,7 +52,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu]
# Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu, openvino]
"VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),
@ -208,6 +212,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
# OpenVINO key-value cache space
# default is 4GB
"VLLM_OPENVINO_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")),
# OpenVINO KV cache precision
# default is bf16 if natively supported by the platform, otherwise f16
# To enable KV cache compression, please explicitly specify u8
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION":
lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None),
# Enables weights compression during model export via HF Optimum
# default is False
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.


@ -0,0 +1,163 @@
from typing import List, Set, Tuple
import openvino as ov
import openvino.properties.hint as hints
import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class OpenVINOExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
self.model_config = _verify_and_get_model_config(self.model_config)
self.cache_config = _verify_and_get_cache_config(self.cache_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
def _init_worker(self):
from vllm.worker.openvino_worker import OpenVINOWorker
assert (
self.parallel_config.world_size == 1
), "OpenVINOExecutor only supports single CPU socket currently."
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker."""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
# NOTE: `cpu block` for the OpenVINO backend is located in CPU memory but is
# referred to as `gpu block`, because we want to reuse the existing block
# management procedure.
logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model
)(execute_model_req=execute_model_req, )
return output
async def check_health_async(self) -> None:
# OpenVINOExecutor will always be healthy as long as
# it's running.
return
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype != torch.float32:
logger.warning(
f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501
)
config.dtype = torch.float32
if not config.enforce_eager:
logger.warning(
"CUDA graph is not supported on OpenVINO backend, fallback to the "
"eager mode.")
config.enforce_eager = True
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
logger.info("KV cache type is overried to u8 via "
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
config.cache_dtype = ov.Type.u8
else:
core = ov.Core()
inference_precision = core.get_property("CPU",
hints.inference_precision)
if inference_precision == ov.Type.bf16:
config.cache_dtype = ov.Type.bf16
else:
config.cache_dtype = ov.Type.f16
if config.block_size != 32:
logger.info(
f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501
)
config.block_size = 32
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0:
_GB = 1 << 30
if kv_cache_space == 0:
config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore
logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.")
else:
config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore
else:
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")
return config
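As a side note, the precision probe used above can be run standalone to see which KV cache precision the executor would pick on a given machine; this is only an illustrative sketch, not part of the change:
import openvino as ov
import openvino.properties.hint as hints

core = ov.Core()
# Mirrors _verify_and_get_cache_config: bf16 when the CPU plugin reports
# native bf16 inference precision, otherwise f16. Setting
# VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 would override both.
inference_precision = core.get_property("CPU", hints.inference_precision)
kv_cache_dtype = ov.Type.bf16 if inference_precision == ov.Type.bf16 else ov.Type.f16
print(kv_cache_dtype)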


@ -679,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
"""
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
@ -965,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding


@ -0,0 +1,210 @@
# ruff: noqa: SIM117
from pathlib import Path
from typing import List, Optional, Tuple
import openvino as ov
import torch
from huggingface_hub import HfApi
from openvino._offline_transformations import paged_attention_transformation
from optimum.intel import OVModelForCausalLM
from torch import nn
import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
def _flattenize_inputs(inputs):
"""
Helper function for flattening nested inputs.
"""
flatten_inputs = []
for input_data in inputs:
if input_data is None:
continue
if isinstance(input_data, (list, tuple)):
flatten_inputs.extend(_flattenize_inputs(input_data))
elif isinstance(input_data, dict):
flatten_inputs.extend(_flattenize_inputs(list(
input_data.values())))
else:
flatten_inputs.append(input_data)
return flatten_inputs
def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
is_cpu: bool):
# Apply hardware dependent modifications to KV tensors
for parameter in model.get_parameters():
input = parameter.get_output_tensor(0)
input_names = input.get_names()
if len(input_names) != 1:
continue
input_name = next(iter(input_names))
shape = parameter.get_partial_shape()
# use real block size if available, just a placeholder
# to provide the expected rank
x_size = 1
num_blocks = ov.Dimension()
block_size = ov.Dimension()
head_size = ov.Dimension()
# TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
# pass more parameters to this function to set more static dimensions
if input_name.startswith("key_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [
num_blocks,
shape[1],
shape[2].get_length() //
x_size if shape[2].is_static else ov.Dimension(),
block_size,
x_size,
]
elif input_name.startswith("value_cache."):
cpu_shape = [num_blocks, shape[1], block_size, head_size]
gpu_shape = [num_blocks, shape[1], shape[2], block_size]
else:
continue
parameter.set_partial_shape(
ov.PartialShape(cpu_shape if is_cpu else gpu_shape))
parameter.set_element_type(kv_cache_dtype)
model.validate_nodes_and_infer_types()
def _require_model_export(model_id, revision=None, subfolder=None):
model_dir = Path(model_id)
if subfolder is not None:
model_dir = model_dir / subfolder
if model_dir.is_dir():
return (not (model_dir / "openvino_model.xml").exists()
or not (model_dir / "openvino_model.bin").exists())
hf_api = HfApi()
try:
model_info = hf_api.model_info(model_id, revision=revision or "main")
normalized_subfolder = (None if subfolder is None else
Path(subfolder).as_posix())
model_files = [
file.rfilename for file in model_info.siblings
if normalized_subfolder is None
or file.rfilename.startswith(normalized_subfolder)
]
ov_model_path = ("openvino_model.xml" if normalized_subfolder is None
else f"{normalized_subfolder}/openvino_model.xml")
return (ov_model_path not in model_files
or ov_model_path.replace(".xml", ".bin") not in model_files)
except Exception:
return True
class OpenVINOCasualLM(nn.Module):
def __init__(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
) -> None:
super().__init__()
self.logits_processor = LogitsProcessor(
model_config.hf_config.vocab_size, logits_as_input=True)
self.sampler = Sampler()
export = _require_model_export(model_config.model)
if export:
logger.warning(
f"Provided model id {model_config.model} does not " # noqa: G004
"contain OpenVINO IR, the model will be converted to IR with "
"default options. If you need to use specific options for "
"model conversion, use optimum-cli export openvino with "
"desired options.")
else:
logger.warning(
"OpenVINO IR is available for provided model id " # noqa: G004
f"{model_config.model}. This IR will be used for inference "
"as-is, all possible options that may affect model conversion "
"are ignored.")
load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
pt_model = OVModelForCausalLM.from_pretrained(
model_config.model,
export=export,
compile=False,
load_in_8bit=load_in_8bit,
trust_remote_code=model_config.trust_remote_code,
)
paged_attention_transformation(pt_model.model)
_modify_cache_parameters(pt_model.model, kv_cache_dtype,
device_config.device.type == "cpu")
core = ov.Core()
ov_compiled = core.compile_model(pt_model.model, "CPU")
self.ov_request = ov_compiled.create_infer_request()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
attn_metadata: OpenVINOAttentionMetadata,
) -> torch.Tensor:
flatten_kv_cache = _flattenize_inputs(kv_caches)
inputs = [
input_ids,
positions,
*flatten_kv_cache,
attn_metadata.past_lens,
attn_metadata.subsequence_begins,
attn_metadata.block_indices,
attn_metadata.block_indices_begins,
attn_metadata.max_context_len,
]
self.ov_request.start_async(inputs, share_inputs=True)
self.ov_request.wait()
logits = torch.from_numpy(self.ov_request.get_tensor("logits").data)
# TODO: remove 'view' once OpenVINO PA drops the 'seq_len' dimension
return logits.view(-1, logits.shape[-1])
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def get_model(
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
**kwargs,
) -> torch.nn.Module:
lora_config = kwargs.get("lora_config", None)
if lora_config:
raise ValueError(
"OpenVINO modeling does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)


@ -176,6 +176,15 @@ def is_cpu() -> bool:
return False
@lru_cache(maxsize=None)
def is_openvino() -> bool:
from importlib.metadata import PackageNotFoundError, version
try:
return "openvino" in version("vllm")
except PackageNotFoundError:
return False
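For context, a tiny sketch (separate from the diff) of what this check relies on: builds produced with VLLM_TARGET_DEVICE=openvino get a "+openvino" local version suffix in setup.py above, which is what is_openvino() looks for. The version string in the comment is hypothetical.
from importlib.metadata import version

# A wheel built with VLLM_TARGET_DEVICE=openvino reports something like
# "0.5.0+openvino", so the substring check below returns True.
print("openvino" in version("vllm"))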
@lru_cache(maxsize=None)
def is_neuron() -> bool:
try:
@ -546,7 +555,7 @@ def is_pin_memory_available() -> bool:
elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif is_cpu():
elif is_cpu() or is_openvino():
return False
return True


@ -0,0 +1,330 @@
from typing import List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
from torch import nn
from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
logger = init_logger(__name__)
class ModelInput(NamedTuple):
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: Optional[OpenVINOAttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
multi_modal_input: Optional[torch.Tensor]
@classmethod
def empty(cls, device):
return ModelInput(input_tokens=torch.empty(0, device=device),
input_positions=torch.empty(0, device=device),
attn_metadata=None,
seq_lens=[],
query_lens=[],
multi_modal_input=None)
class OpenVINOModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Lazy initialization.
self.model: nn.Module  # Set after load_model()
def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
device_config=self.device_config,
kv_cache_dtype=self.kv_cache_dtype,
)
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInput:
"""Prepare the model input based on a given sequence group.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
"""
input_tokens: List[int] = []
input_positions: List[int] = []
seq_lens: List[int] = []
past_lens: List[int] = []
query_lens: List[int] = []
subsequence_begins: List[int] = []
block_indices: List[int] = []
block_indices_begins: List[int] = []
# initialize beginning of prefix sums
subsequence_begins.append(0)
block_indices_begins.append(0)
if len(seq_group_metadata_list) == 0:
return ModelInput.empty(self.device)
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
is_prompt = seq_group_metadata.is_prompt
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
if is_prompt:
computed_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
computed_len = seq_data.get_len() - 1
seq_len = min(
seq_data.get_len(),
computed_len + seq_group_metadata.token_chunk_size,
)
if is_prompt:
tokens = seq_data.get_token_ids()[computed_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
block_table = seq_group_metadata.block_tables[seq_id]
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
computed_len = len(computed_block_nums) * self.block_size
tokens = tokens[computed_len:]
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if self.sliding_window is not None:
# chunked prefill doesn't support sliding window.
assert not self.scheduler_config.chunked_prefill_enabled # noqa: E501
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []
else:
# prompt phase w/o prefix_caching, chunked_prefill
pass
block_indices.extend(block_table)
block_indices_begins.append(block_indices_begins[-1] +
len(block_table))
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it once the paged attn kernel
# properly handles sliding window attn.
if self.sliding_window is not None and not is_prompt:
seq_len = min(seq_len, self.sliding_window)
computed_len = seq_len - 1
seq_lens.append(seq_len)
query_len = seq_len - computed_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(computed_len, seq_len)))
past_lens.append(computed_len)
subsequence_begins.append(subsequence_begins[-1] + query_len)
if is_prompt:
assert len(seq_ids) == 1
else:
assert (
query_len == 1
), "seq_len: {}, computed_len: {}, query_len: {}".format(
seq_len, computed_len, query_len)
max_query_len = max(query_lens)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore
past_lens_tensor = torch.tensor(past_lens,
dtype=torch.int32,
device=self.device) # type: ignore
subsequence_begins_tensor = torch.tensor(
subsequence_begins, dtype=torch.int32,
device=self.device) # type: ignore
block_indices_tensor = torch.tensor(block_indices,
dtype=torch.int32,
device=self.device) # type: ignore
block_indices_begins_tensor = torch.tensor(
block_indices_begins, dtype=torch.int32,
device=self.device) # type: ignore
max_context_len = max(seq_lens)
max_context_len_tensor = torch.tensor(
max_context_len, dtype=torch.int32,
device=self.device) # type: ignore
attn_metadata = self.attn_backend.make_openvino_metadata(
past_lens=past_lens_tensor,
subsequence_begins=subsequence_begins_tensor,
block_indices=block_indices_tensor,
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
)
return ModelInput(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
None,
)
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
SamplingMetadata, Optional[torch.Tensor], ]:
multi_modal_input = None
# Prepare input tensors.
(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
multi_modal_input,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens,
self.device,
pin_memory=False,
)
return (
input_tokens,
input_positions,
attn_metadata,
sampling_metadata,
multi_modal_input,
)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
) -> Optional[SamplerOutput]:
(
input_tokens,
input_positions,
attn_metadata,
sampling_metadata,
multi_modal_input,
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
return output


@ -0,0 +1,353 @@
"""An OpenVINO worker class."""
from typing import Any, Dict, List, Optional, Tuple
import openvino as ov
import torch
import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.openvino_model_runner import OpenVINOModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
logger = init_logger(__name__)
class OpenVINOCacheEngine:
"""Manages the KV cache for OpenVINO backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(
self,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig,
) -> None:
assert device_config.device_type == "openvino"
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.head_size = model_config.get_head_size()
if device_config.device.type == "cpu" and \
cache_config.cache_dtype == ov.Type.u8:
# Scale, zero point and quantized data will be stored together.
# The layout for per token per head:
# |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
# so, we have to extend head_size by 8, which is sizeof(float)
# for scale and sizeof(float) for zeropoint
self.head_size += 8
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
# Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
# for the OpenVINO backend, because we want to reuse KV cache management
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.head_size,
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
)
# Initialize the cache.
self.kv_cache: List[Tuple[ov.Tensor,
ov.Tensor]] = self._allocate_kv_cache(
self.num_cpu_blocks)
def _allocate_kv_cache(
self,
num_blocks: int,
) -> List[Tuple[ov.Tensor, ov.Tensor]]:
"""Allocates KV cache."""
k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:]
kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = []
for _ in range(self.num_layers):
key_blocks = ov.Tensor(self.cache_config.cache_dtype,
k_block_shape)
value_blocks = ov.Tensor(self.cache_config.cache_dtype,
v_block_shape)
kv_cache.append((key_blocks, value_blocks))
return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
raise NotImplementedError(
"Swap is not supported in OpenVINOCacheEngine.")
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
raise NotImplementedError(
"Swap is not supported in OpenVINOCacheEngine.")
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts)
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: ov.Type,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
if cache_dtype == ov.Type.u8:
# Scale, zero point and quantized data will be stored together.
# The layout for per token per head:
# |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501
# so, we have to extend head_size by 8, which is sizeof(float)
# for scale and sizeof(float) for zeropoint
head_size += 8
key_cache_block = block_size * num_kv_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = cache_dtype.size
return dtype_size * total
class OpenVINOWorker(LoraNotSupportedWorkerBase):
"""A worker class that executes the model on OpenVINO backend.
Each worker is associated with a single OpenVINO device. The worker is
responsible for maintaining the KV cache and executing the model on the
OpenVINO backend.
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = OpenVINOModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: OpenVINOCacheEngine
self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]]
def init_device(self) -> None:
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For OpenVINO backend, the block number will be calculated based on the
# openvino_kvcache_space_bytes.
cache_block_size = self.get_cache_block_size_bytes()
num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes //
cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks = num_cpu_blocks
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert (num_cpu_blocks == 0
), f"{type(self)} does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks = num_gpu_blocks
self._validate_num_cpu_blocks(num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_cpu_blocks
self.cache_config.num_cpu_blocks = 0
# Initialize the cache.
self._init_cache_engine()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid."""
if num_cpu_blocks <= 0:
raise ValueError(
"No available memory for the cache blocks. "
"Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` "
"when initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = OpenVINOCacheEngine(
self.cache_config,
self.model_config,
self.parallel_config,
self.device_config,
)
self.kv_cache = self.cache_engine.kv_cache
self.model_runner.block_size = self.cache_engine.block_size
assert self.kv_cache is not None
# Populate the cache to warmup the memory
for key_cache, value_cache in self.kv_cache:
key_cache.data[:] = 0
value_cache.data[:] = 0
def cache_copy(
self,
blocks_to_copy: List[Tuple[int, int]],
) -> None:
self.cache_engine.copy(blocks_to_copy) # type: ignore
@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list)
assert execute_model_req is not None
blocks_to_copy = execute_model_req.blocks_to_copy
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": execute_model_req.blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.kv_cache)
# OpenVINO worker only supports single-step execution.
return [output]
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block."""
return OpenVINOCacheEngine.get_cache_block_size(
self.cache_config.block_size,
self.cache_config.cache_dtype,
self.model_config,
self.parallel_config,
)
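To make the block-size accounting above concrete, here is a small back-of-the-envelope sketch (separate from the diff) with assumed Llama-2-7B-like dimensions and an f16 KV cache (2 bytes per element, the non-u8 path); none of the numbers below are read from a real model config:
# Assumed dimensions: 32 layers, 32 KV heads, head_size 128, block_size 32.
block_size, num_kv_heads, head_size, num_layers, dtype_size = 32, 32, 128, 32, 2

# Mirrors OpenVINOCacheEngine.get_cache_block_size for the f16 case.
key_cache_block = block_size * num_kv_heads * head_size    # elements per key block
value_cache_block = key_cache_block                        # same for the value block
total_elements = num_layers * (key_cache_block + value_cache_block)
cache_block_size_bytes = dtype_size * total_elements       # 16 MiB per KV block

# Mirrors determine_num_available_blocks with the default 4 GB of
# VLLM_OPENVINO_KVCACHE_SPACE.
num_cpu_blocks = (4 * (1 << 30)) // cache_block_size_bytes  # 256 blocks
max_cached_tokens = num_cpu_blocks * block_size             # 8192 tokens
print(cache_block_size_bytes, num_cpu_blocks, max_cached_tokens)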