[Hardware][Intel GPU] Add Intel GPU(XPU) inference backend (#3814)
Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Abhilash Majumder <abhilash.majumder@intel.com>
Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
parent 1f12122b17
commit 728c4c8a06
.buildkite/run-xpu-test.sh (new file, 14 lines)
@@ -0,0 +1,14 @@
# This script builds the XPU docker image and runs the offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t xpu-test -f Dockerfile.xpu .

# Setup cleanup
remove_docker_container() { docker rm -f xpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py
@@ -45,6 +45,11 @@ steps:
     queue: intel
   command: bash .buildkite/run-cpu-test.sh

+- label: "XPU Test"
+  agents:
+    queue: intel
+  command: bash .buildkite/run-xpu-test.sh
+
{% for step in steps %}
- label: "{{ step.label }}"
  agents:
Dockerfile.xpu (new file, 22 lines)
@@ -0,0 +1,22 @@
FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
    rm /etc/apt/sources.list.d/intel-graphics.list && \
    wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
    chmod 644 /usr/share/keyrings/intel-graphics.gpg

RUN apt-get update -y \
    && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

RUN pip install -v -r requirements-xpu.txt

RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install

CMD ["/bin/bash"]
@@ -191,7 +191,7 @@ if __name__ == '__main__':
        "--device",
        type=str,
        default="cuda",
-        choices=["cuda", "cpu", "tpu"],
+        choices=["cuda", "cpu", "tpu", "xpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument('--block-size',
                        type=int,
@ -349,7 +349,7 @@ if __name__ == "__main__":
|
|||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda",
|
default="cuda",
|
||||||
choices=["cuda", "cpu", "tpu"],
|
choices=["cuda", "cpu", "tpu", "xpu"],
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
|
docs/source/getting_started/xpu-installation.rst (new file, 61 lines)
@@ -0,0 +1,61 @@
.. _installation_xpu:

Installation with XPU
========================

vLLM initially supports basic model inference and serving on the Intel GPU platform.

Table of contents:

#. :ref:`Requirements <xpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <xpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_xpu_backend_from_source>`

.. _xpu_backend_requirements:

Requirements
------------

* OS: Linux
* Supported Hardware: Intel Data Center GPU (Intel ARC GPU WIP)
* OneAPI requirements: oneAPI 2024.1

.. _xpu_backend_quick_start_dockerfile:

Quick start using Dockerfile
----------------------------

.. code-block:: console

    $ docker build -f Dockerfile.xpu -t vllm-xpu-env --shm-size=4g .
    $ docker run -it \
                 --rm \
                 --network=host \
                 --device /dev/dri \
                 -v /dev/dri/by-path:/dev/dri/by-path \
                 vllm-xpu-env

.. _build_xpu_backend_from_source:

Build from source
-----------------

- First, install the required driver and Intel oneAPI 2024.1.

- Second, install Python packages for vLLM XPU backend building:

.. code-block:: console

    $ pip install --upgrade pip
    $ pip install -v -r requirements-xpu.txt

- Finally, build and install the vLLM XPU backend:

.. code-block:: console

    $ VLLM_TARGET_DEVICE=xpu python setup.py install

.. note::
    - FP16 is the default data type in the current XPU backend. The BF16 data
      type will be supported in the future.
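Not part of the diff: once the backend above is installed (via the Dockerfile or from source), a minimal offline-inference sketch looks like the following. The model name is only illustrative, and device="xpu" relies on the DeviceConfig/EngineArgs changes added by this PR.

# Hedged sketch (not in the diff): offline inference on the XPU backend.
# Assumes vLLM was built with VLLM_TARGET_DEVICE=xpu and an Intel GPU is visible.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# "facebook/opt-125m" is just an illustrative model choice.
llm = LLM(model="facebook/opt-125m", device="xpu")

for output in llm.generate(prompts, sampling_params):
    print(output.prompt, output.outputs[0].text)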
@@ -66,6 +66,7 @@ Documentation
   getting_started/cpu-installation
   getting_started/neuron-installation
   getting_started/tpu-installation
+   getting_started/xpu-installation
   getting_started/quickstart
   getting_started/debugging
   getting_started/examples/examples_index
requirements-xpu.txt (new file, 11 lines)
@@ -0,0 +1,11 @@
# Common dependencies
-r requirements-common.txt

setuptools < 70.0.0  # IPEX's torch has some dependency on this; to be removed.

torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl
intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl

triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
setup.py (8 added lines)
@@ -233,6 +233,10 @@ def _is_cpu() -> bool:
    return VLLM_TARGET_DEVICE == "cpu"


+def _is_xpu() -> bool:
+    return VLLM_TARGET_DEVICE == "xpu"
+
+
def _build_custom_ops() -> bool:
    return _is_cuda() or _is_hip() or _is_cpu()

@@ -337,6 +341,8 @@ def get_vllm_version() -> str:
        version += "+tpu"
    elif _is_cpu():
        version += "+cpu"
+    elif _is_xpu():
+        version += "+xpu"
    else:
        raise RuntimeError("Unknown runtime environment")

@@ -386,6 +392,8 @@ def get_requirements() -> List[str]:
        requirements = _read_requirements("requirements-tpu.txt")
    elif _is_cpu():
        requirements = _read_requirements("requirements-cpu.txt")
+    elif _is_xpu():
+        requirements = _read_requirements("requirements-xpu.txt")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
@@ -373,7 +373,8 @@ def reshape_and_cache_flash(
                                                kv_cache_dtype)


-def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
+def copy_blocks(key_caches: List[torch.Tensor],
+                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

vllm/_ipex_ops.py (new file, 241 lines)
@@ -0,0 +1,241 @@
from typing import List, Optional, Tuple

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning("Import error msg: %s", e.msg)


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.silu_mul(x1, x2, out)

    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "none")

    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        kv_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v1(out, query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, scale,
                                     block_tables, context_lens, block_size,
                                     max_context_len, alibi_slopes)

    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        kv_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
                                     query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, block_tables,
                                     context_lens, scale, block_size,
                                     max_context_len, alibi_slopes)

    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)

        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[positions.long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[torch.add(positions,
                                          cos_sin_cache_offsets).long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)

    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
        out.copy_(tmp)

    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
                                             seqlen_k, max_seqlen_q,
                                             max_seqlen_k, pdropout,
                                             softmax_scale, zero_tensors,
                                             is_causal, return_softmax, gen_)

    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)

    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)
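Not part of the diff: the ipex_ops class above mirrors the existing custom-ops interface, so callers pass pre-allocated output tensors into each kernel wrapper. A small hedged sketch of the activation entry point, assuming intel_extension_for_pytorch is installed and an XPU device is present:

# Hedged sketch (not in the diff): exercising ipex_ops.silu_and_mul directly.
import torch

from vllm._ipex_ops import ipex_ops

num_tokens, d = 4, 128
x = torch.randn(num_tokens, 2 * d, device="xpu", dtype=torch.float16)
out = torch.empty(num_tokens, d, device="xpu", dtype=torch.float16)

# Splits x into two halves along the last dim and writes silu(x1) * x2 into out,
# matching the CUDA custom op it replaces.
ipex_ops.silu_and_mul(out, x)
print(out.shape)  # torch.Size([4, 128])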
vllm/attention/backends/ipex_attn.py (new file, 355 lines)
@@ -0,0 +1,355 @@
""" Attention layer with torch scaled_dot_product_attention
    and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)

_PARTITION_SIZE = 512


class IpexAttnBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "ipex-attn"

    @staticmethod
    def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
        return IpexAttnBackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
        return IpexAttnMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for IpexAttnBackend.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    seqlen_q: Optional[torch.Tensor]
    max_seqlen: Optional[int]

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it is needed to set per prompt
        # when alibi slopes is used. It is because of the limitation
        # from xformer API.
        # will not appear in the __repr__ and __init__
        self.attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        return None

    @property
    def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None

        return self


class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "Torch SPDA does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 1
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: IpexAttnMetadata,  # type: ignore
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert kv_scale == 1.0
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                kv_scale,
            )

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=1)

                if attn_metadata.attn_bias is None:
                    if self.alibi_slopes is not None:
                        att_masks = _make_alibi_bias(
                            self.alibi_slopes, query.dtype,
                            attn_metadata.seq_lens)  # type: ignore
                    elif self.sliding_window is not None:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, self.sliding_window,
                            query.dtype)  # type: ignore
                    else:
                        att_masks = _make_sliding_window_bias(
                            attn_metadata.seq_lens, None, dtype=query.dtype)
                    attn_metadata.attn_bias = att_masks

                output = torch.empty(
                    (num_tokens, self.num_heads, self.head_size),
                    dtype=query.dtype,
                    device=query.device)
                ipex_ops.varlen_attention(query,
                                          key,
                                          value,
                                          output,
                                          attn_metadata.seqlen_q,
                                          attn_metadata.seqlen_q,
                                          attn_metadata.max_seqlen,
                                          attn_metadata.max_seqlen,
                                          pdropout=0.0,
                                          softmax_scale=self.scale,
                                          zero_tensors=False,
                                          is_causal=True,
                                          return_softmax=False,
                                          gen_=None)
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")

        else:
            # Decoding run.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            num_seqs, num_heads, head_size = query.shape
            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                                  _PARTITION_SIZE)
            # NOTE(woosuk): We use a simple heuristic to decide whether to use
            # PagedAttention V1 or V2. If the number of partitions is 1, we use
            # V1 to avoid the overhead of reduction. Also, if the number of
            # sequences or heads is large, we use V1 since there is enough work
            # to parallelize.
            # TODO(woosuk): Tune this heuristic.
            # For context len > 8192, use V2 kernel to avoid shared memory
            # shortage.
            use_v1 = (max_seq_len <= 8192 and
                      (max_num_partitions == 1 or num_seqs * num_heads > 512))
            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    kv_scale,
                )
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ipex_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    attn_metadata.block_tables,
                    attn_metadata.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    kv_scale,
                )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None])
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype,
            device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases


def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
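Not part of the diff: a quick shape check of the flat KV-cache layout this backend expects. The forward() docstring above gives the cache shape, and split_kv_cache() carves it up with only view operations, so the sketch runs on CPU.

# Hedged sketch (not in the diff): the KV-cache layout used by IpexAttnBackend.
# The cache is packed as [2, num_blocks, block_size * num_kv_heads * head_size];
# split_kv_cache() reshapes each half as in the code above (with x = 1).
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
kv_cache = torch.zeros(2, num_blocks, block_size * num_kv_heads * head_size)

x = 1
key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)

print(key_cache.shape)    # torch.Size([8, 4, 64, 16, 1])
print(value_cache.shape)  # torch.Size([8, 4, 64, 16])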
@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
-from vllm.utils import is_cpu, is_hip, is_tpu
+from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu

logger = init_logger(__name__)

@@ -19,6 +19,7 @@ class _Backend(enum.Enum):
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    PALLAS = enum.auto()
+    IPEX = enum.auto()


@lru_cache(maxsize=None)
@@ -58,12 +59,17 @@ def get_attn_backend(
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
-        # TODO: make XPU backend available here.
        assert is_cpu(), RuntimeError(
            "Torch SDPA backend is only used for the CPU device.")
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
+    elif backend == _Backend.IPEX:
+        assert is_xpu(), RuntimeError(
+            "IPEX attention backend is only used for the XPU device.")
+        logger.info("Using IPEX attention backend.")
+        from vllm.attention.backends.ipex_attn import IpexAttnBackend
+        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
@@ -107,6 +113,11 @@ def which_attn_to_use(
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA

+    if is_xpu():
+        if selected_backend != _Backend.IPEX:
+            logger.info("Cannot use %s backend on XPU.", selected_backend)
+        return _Backend.IPEX
+
    if is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
-                        is_hip, is_neuron, is_tpu)
+                        is_hip, is_neuron, is_tpu, is_xpu)

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
@@ -757,6 +757,8 @@ class DeviceConfig:
            self.device_type = "tpu"
        elif is_cpu():
            self.device_type = "cpu"
+        elif is_xpu():
+            self.device_type = "xpu"
        else:
            # We don't call torch.cuda.is_available() here to
            # avoid initializing CUDA before workers are forked
@@ -58,7 +58,7 @@ def _split_tensor_dict(
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
-            device = "cpu" if value.is_cpu else "cuda"
+            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
@@ -501,10 +501,11 @@ class EngineArgs:
                'Enabling this will use the fully sharded layers. '
                'At high sequence length, max rank or '
                'tensor parallel size, this is likely faster.'))
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=["auto", "cuda", "neuron", "cpu", "tpu"],
-                            help='Device type for vLLM execution.')
+        parser.add_argument(
+            "--device",
+            type=str,
+            default=EngineArgs.device,
+            choices=["auto", "cuda", "neuron", "cpu", "tpu", "xpu"],
+            help='Device type for vLLM execution.')

        # Related to Vision-language models such as llava
@@ -383,6 +383,17 @@ class AsyncLLMEngine:
                "Distributed execution is not supported with the CPU backend.")
            from vllm.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
+        elif engine_config.device_config.device_type == "xpu":
+            if distributed_executor_backend is None:
+                from vllm.executor.xpu_executor import XPUExecutorAsync
+                executor_class = XPUExecutorAsync
+            elif distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
+                executor_class = RayXPUExecutorAsync
+            else:
+                raise RuntimeError(
+                    "Not supported distributed execution model on XPU device.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
@@ -347,6 +347,14 @@ class LLMEngine:
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
+        elif engine_config.device_config.device_type == "xpu":
+            if distributed_executor_backend == "ray":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.ray_xpu_executor import RayXPUExecutor
+                executor_class = RayXPUExecutor
+            else:
+                from vllm.executor.xpu_executor import XPUExecutor
+                executor_class = XPUExecutor
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
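Not part of the diff: a hedged sketch of where this executor dispatch lands in practice. EngineArgs and LLMEngine.from_engine_args are the existing entry points; the model name is illustrative and the exact add_request signature may differ slightly between vLLM versions.

# Hedged sketch (not in the diff): device="xpu" flows into DeviceConfig, and
# from_engine_args then selects XPUExecutor (or RayXPUExecutor when the Ray
# distributed executor backend is requested).
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine_args = EngineArgs(model="facebook/opt-125m", device="xpu")  # illustrative model
engine = LLMEngine.from_engine_args(engine_args)

engine.add_request("request-0", "The Intel Data Center GPU is",
                   SamplingParams(max_tokens=32))
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)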
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
-from vllm.utils import get_ip, is_hip
+from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -71,7 +71,7 @@ def initialize_ray_cluster(
            "serving.")

    # Connect to a ray cluster.
-    if is_hip():
+    if is_hip() or is_xpu():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
401
vllm/executor/ray_xpu_executor.py
Normal file
401
vllm/executor/ray_xpu_executor.py
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict
|
||||||
|
from itertools import islice, repeat
|
||||||
|
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
|
||||||
|
Tuple, Union)
|
||||||
|
|
||||||
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||||
|
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||||
|
SpeculativeConfig, VisionLanguageConfig)
|
||||||
|
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||||
|
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||||
|
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||||
|
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||||
|
make_async)
|
||||||
|
|
||||||
|
if ray is not None:
|
||||||
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# If the env var is set, it uses the Ray's compiled DAG API
|
||||||
|
# which optimizes the control plane overhead.
|
||||||
|
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||||
|
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
||||||
|
|
||||||
|
|
||||||
|
class RayXPUExecutor(DistributedGPUExecutor):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
cache_config: CacheConfig,
|
||||||
|
parallel_config: ParallelConfig,
|
||||||
|
scheduler_config: SchedulerConfig,
|
||||||
|
device_config: DeviceConfig,
|
||||||
|
load_config: LoadConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
|
speculative_config: Optional[SpeculativeConfig],
|
||||||
|
) -> None:
|
||||||
|
assert device_config.device_type == "xpu"
|
||||||
|
assert (not speculative_config
|
||||||
|
), "Speculative decoding not yet supported for XPU backend"
|
||||||
|
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.load_config = load_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
|
||||||
|
placement_group = self.parallel_config.placement_group
|
||||||
|
|
||||||
|
# Disable Ray usage stats collection.
|
||||||
|
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||||
|
if ray_usage != "1":
|
||||||
|
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||||
|
|
||||||
|
# Create the parallel GPU workers.
|
||||||
|
self._init_workers_ray(placement_group)
|
||||||
|
|
||||||
|
# Profile the memory usage and initialize the cache.
|
||||||
|
self.forward_dag = None
|
||||||
|
if USE_RAY_COMPILED_DAG:
|
||||||
|
self.forward_dag = self._compiled_ray_dag()
|
||||||
|
|
||||||
|
# This is non-None when the execute model loop is running
|
||||||
|
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||||
|
self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
|
||||||
|
# Updated by implementations that require additional args to be passed
|
||||||
|
# to the _run_workers execute_model call
|
||||||
|
self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def _init_executor(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
|
"""Determine the number of available KV blocks.
|
||||||
|
|
||||||
|
This invokes `determine_num_available_blocks` on each worker and takes
|
||||||
|
the min of the results, guaranteeing that the selected cache sizes are
|
||||||
|
compatible with all workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Tuple[num_gpu_blocks, num_cpu_blocks]
|
||||||
|
"""
|
||||||
|
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||||
|
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||||
|
|
||||||
|
# Since we use a shared centralized controller, we take the minimum
|
||||||
|
# number of blocks across all workers to make sure all the memory
|
||||||
|
# operators can be applied to all workers.
|
||||||
|
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||||
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
|
||||||
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
|
**ray_remote_kwargs):
|
||||||
|
if self.parallel_config.tensor_parallel_size == 1:
|
||||||
|
# For single GPU case, we use a ray worker with constrained memory.
|
||||||
|
num_gpus = self.cache_config.gpu_memory_utilization
|
||||||
|
else:
|
||||||
|
# Otherwise, the ray workers are allocated with a full GPU.
|
||||||
|
num_gpus = 1
|
||||||
|
|
||||||
|
# The driver dummy worker does not actually use any resources.
|
||||||
|
# It holds the resource for the driver worker.
|
||||||
|
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||||
|
# The remaining workers are the actual ray actors.
|
||||||
|
self.workers: List[RayWorkerWrapper] = []
|
||||||
|
|
||||||
|
# Create the workers.
|
||||||
|
driver_ip = get_ip()
|
||||||
|
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||||
|
if not bundle.get("GPU", 0):
|
||||||
|
continue
|
||||||
|
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=placement_group,
|
||||||
|
placement_group_capture_child_tasks=True,
|
||||||
|
placement_group_bundle_index=bundle_id,
|
||||||
|
)
|
||||||
|
worker = ray.remote(
|
||||||
|
num_cpus=0,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
scheduling_strategy=scheduling_strategy,
|
||||||
|
**ray_remote_kwargs,
|
||||||
|
)(RayWorkerWrapper).remote(
|
||||||
|
worker_module_name="vllm.worker.xpu_worker",
|
||||||
|
worker_class_name="XPUWorker",
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||||
|
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||||
|
# If the worker is on the same node as the driver, we use it
|
||||||
|
# as the resource holder for the driver process.
|
||||||
|
self.driver_dummy_worker = worker
|
||||||
|
self.driver_worker = RayWorkerWrapper(
|
||||||
|
worker_module_name="vllm.worker.xpu_worker",
|
||||||
|
worker_class_name="XPUWorker",
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Else, added to the list of workers.
|
||||||
|
self.workers.append(worker)
|
||||||
|
if self.driver_dummy_worker is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||||
|
"adjusting the Ray placement group or running the driver on a "
|
||||||
|
"GPU node.")
|
||||||
|
|
||||||
|
# Get the set of GPU IDs used on each node.
|
||||||
|
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
|
||||||
|
use_dummy_driver=True)
|
||||||
|
|
||||||
|
node_workers = defaultdict(list)
|
||||||
|
node_gpus = defaultdict(list)
|
||||||
|
|
||||||
|
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||||
|
node_workers[node_id].append(i)
|
||||||
|
node_gpus[node_id].extend(gpu_ids)
|
||||||
|
for node_id, gpu_ids in node_gpus.items():
|
||||||
|
node_gpus[node_id] = sorted(gpu_ids)
|
||||||
|
|
||||||
|
# TODO: add env var for xpu
|
||||||
|
|
||||||
|
distributed_init_method = get_distributed_init_method(
|
||||||
|
driver_ip, get_open_port())
|
||||||
|
|
||||||
|
def collect_arg_helper_func(**kwargs):
|
||||||
|
# avoid writing `{"name": value}` manually
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
init_worker_all_kwargs = []
|
||||||
|
|
||||||
|
# Initialize the actual workers inside worker wrapper.
|
||||||
|
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
|
||||||
|
local_rank = node_workers[node_id].index(rank)
|
||||||
|
init_worker_all_kwargs.append(
|
||||||
|
collect_arg_helper_func(
|
||||||
|
model_config=self.model_config,
|
||||||
|
parallel_config=self.parallel_config,
|
||||||
|
scheduler_config=self.scheduler_config,
|
||||||
|
device_config=self.device_config,
|
||||||
|
cache_config=self.cache_config,
|
||||||
|
load_config=self.load_config,
|
||||||
|
local_rank=local_rank,
|
||||||
|
rank=rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
|
vision_language_config=self.vision_language_config,
|
||||||
|
is_driver_worker=rank == 0,
|
||||||
|
))
|
||||||
|
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||||
|
|
||||||
|
self._run_workers("init_device")
|
||||||
|
self._run_workers(
|
||||||
|
"load_model",
|
||||||
|
max_concurrent_workers=self.parallel_config.
|
||||||
|
max_parallel_loading_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
|
num_cpu_blocks: int) -> None:
|
||||||
|
"""Initialize the KV cache in all workers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||||
|
# greater than one. We could log in the engine, but not all executors
|
||||||
|
# have GPUs.
|
||||||
|
logger.info("# GPU blocks: %d, "
|
||||||
|
"# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)
|
||||||
|
|
||||||
|
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
|
self._run_workers("initialize_cache",
|
||||||
|
num_gpu_blocks=num_gpu_blocks,
|
||||||
|
num_cpu_blocks=num_cpu_blocks)
|
||||||
|
|
||||||
|
def _driver_execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> List[SamplerOutput]:
|
||||||
|
"""Run execute_model in the driver worker.
|
||||||
|
|
||||||
|
Passing None will cause the driver to stop the model execution
|
||||||
|
loop running in each of the remote workers.
|
||||||
|
"""
|
||||||
|
return self.driver_worker.execute_method("execute_model",
|
||||||
|
execute_model_req)
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self._run_workers(
|
||||||
|
"add_lora",
|
||||||
|
lora_request=lora_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - args/kwargs: All workers share the same args/kwargs
        - args/kwargs and driver_args/driver_kwargs: Driver worker has
          different args
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        if use_ray_compiled_dag:
            # Right now, compiled DAG can only accept a single
            # input. TODO(sang): Fix it.
            assert self.forward_dag is not None
            output_channels = self.forward_dag.execute(1)
        else:
            # Start the ray workers first.
            ray_worker_outputs = [
                worker.execute_method.remote(method, *worker_args,
                                             **worker_kwargs)
                for (worker, worker_args, worker_kwargs
                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
            ]
            if async_run_remote_workers_only:
                # Just return futures
                return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))
        # Get the results of the ray workers.
        if self.workers:
            if use_ray_compiled_dag:
                try:
                    ray_worker_outputs = [
                        pickle.loads(chan.begin_read())
                        for chan in output_channels
                    ]
                finally:
                    # Has to call end_read in order to reuse the DAG.
                    for chan in output_channels:
                        chan.end_read()
            else:
                ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def _compiled_ray_dag(self):
        import pkg_resources
        required_version = "2.9"
        current_version = pkg_resources.get_distribution("ray").version
        if current_version < required_version:
            raise ValueError(f"Ray version {required_version} or greater is "
                             f"required, but found {current_version}")

        from ray.dag import InputNode, MultiOutputNode
        assert self.parallel_config.worker_use_ray

        # Right now, compiled DAG requires at least 1 arg. We send
        # a dummy value for now. It will be fixed soon.
        with InputNode() as input_data:
            forward_dag = MultiOutputNode([
                worker.execute_model_compiled_dag_remote.
                bind(  # type: ignore[attr-defined]
                    input_data) for worker in self.workers
            ])
        return forward_dag.experimental_compile()

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        self._check_if_any_actor_is_dead()

    def _check_if_any_actor_is_dead(self):
        if not self.workers:
            return

        dead_actors = []
        for actor in self.workers:
            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
            if actor_state["State"] == "DEAD":
                dead_actors.append(actor)
        if dead_actors:
            raise RuntimeError("At least one Worker is dead. "
                               f"Dead Workers: {dead_actors}. ")


class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
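The _run_workers method above fans a single named call out to every Ray worker and then invokes the same method on the in-process driver worker, returning the driver's result first. A minimal standalone sketch of that pattern, using a hypothetical EchoWorker rather than the vLLM worker classes:

import ray

@ray.remote
class EchoWorker:
    # Hypothetical worker: just echoes the method name and arguments.
    def execute_method(self, method, *args, **kwargs):
        return (method, args, kwargs)

def run_workers(driver, workers, method, *args, **kwargs):
    # Launch the remote workers first, then run the driver in-process,
    # mirroring the ordering used by RayXPUExecutor._run_workers.
    futures = [w.execute_method.remote(method, *args, **kwargs) for w in workers]
    driver_output = driver(method, *args, **kwargs)
    return [driver_output] + ray.get(futures)

if __name__ == "__main__":
    ray.init()
    workers = [EchoWorker.remote() for _ in range(2)]
    driver = lambda method, *a, **kw: (method, a, kw)
    print(run_workers(driver, workers, "list_loras"))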
98
vllm/executor/xpu_executor.py
Normal file
@ -0,0 +1,98 @@
from typing import List, Optional

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class XPUExecutor(GPUExecutor):

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"

        model_config = _verify_and_get_model_config(model_config)

        self.model_config = model_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config
        self.speculative_config = None

        # Instantiate the worker and load the model to GPU.
        self._init_executor()

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        if self.speculative_config is None:
            worker_module_name = "vllm.worker.xpu_worker"
            worker_class_name = "XPUWorker"
        else:
            raise NotImplementedError(
                "XPU does not support speculative decoding")

        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output


class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        output = await make_async(self.driver_worker.execute_model
                                  )(execute_model_req=execute_model_req)
        return output


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    if config.dtype == torch.bfloat16:
        logger.warning(
            "bfloat16 is not fully supported on XPU, casting to float16.")
        config.dtype = torch.float16
    if not config.enforce_eager:
        logger.warning(
            "CUDA graph is not supported on XPU, fallback to the eager "
            "mode.")
        config.enforce_eager = True
    return config
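A minimal sketch of driving this executor end to end through the offline API, assuming an XPU build of vLLM; the checkpoint name is only an example and any supported model would do:

from vllm import LLM, SamplingParams

# Assumed small model; device="xpu" selects the XPUExecutor path.
llm = LLM(model="facebook/opt-125m", device="xpu", enforce_eager=True)
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The Intel GPU backend of vLLM"], params)
for out in outputs:
    print(out.outputs[0].text)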
@ -1,6 +1,6 @@
 import torch.nn as nn

-from vllm.utils import is_cpu, is_hip, is_tpu
+from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu


 class CustomOp(nn.Module):
@ -29,9 +29,7 @@ class CustomOp(nn.Module):
         return self.forward_cuda(*args, **kwargs)

     def forward_xpu(self, *args, **kwargs):
-        # By default, we assume that XPU ops are compatible with CUDA ops.
-        # NOTE(woosuk): This is a placeholder for future extensions.
-        return self.forward_cuda(*args, **kwargs)
+        raise NotImplementedError

     def forward_cpu(self, *args, **kwargs):
         # By default, we assume that CPU ops are compatible with CUDA ops.
@ -58,5 +56,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif is_tpu():
             return self.forward_tpu
+        elif is_xpu():
+            return self.forward_xpu
         else:
             return self.forward_cuda
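The new elif branch completes a simple dispatch pattern: each CustomOp picks one forward_* implementation for the detected platform when it is constructed. A standalone sketch of the same idea, with illustrative stand-ins rather than the vLLM classes:

import torch
import torch.nn as nn

def xpu_available() -> bool:
    # Illustrative stand-in for the platform check.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

class DispatchingOp(nn.Module):
    # Chooses the backend-specific forward once, at init time.
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        if xpu_available():
            return self.forward_xpu
        return self.forward_cuda

    def forward_cuda(self, x):
        return x * 2  # placeholder compute

    def forward_xpu(self, x):
        return x * 2  # same math; would call an IPEX kernel in practice

print(DispatchingOp()(torch.ones(2)))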
@ -37,6 +37,15 @@ class SiluAndMul(CustomOp):
         ops.silu_and_mul(out, x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.silu_and_mul(out, x)
+        return out
+

 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
@ -71,6 +80,18 @@ class GeluAndMul(CustomOp):
         ops.gelu_tanh_and_mul(out, x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if self.approximate == "none":
+            ops.gelu_and_mul(out, x)
+        elif self.approximate == "tanh":
+            ops.gelu_tanh_and_mul(out, x)
+        return out
+
     def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'

@ -90,6 +111,13 @@ class NewGELU(CustomOp):
         ops.gelu_new(out, x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_new(out, x)
+        return out
+

 class FastGELU(CustomOp):

@ -105,6 +133,13 @@ class FastGELU(CustomOp):
         ops.gelu_fast(out, x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_fast(out, x)
+        return out
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
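These forward_xpu methods mirror the CUDA kernels' contract. For silu_and_mul that contract is: split the last dimension in half and multiply the SiLU of the first half by the second half. A plain-PyTorch reference (for understanding only, not the IPEX kernel):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * d); the output has shape (..., d).
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 8])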
@ -67,6 +67,30 @@ class RMSNorm(CustomOp):
         )
         return out

+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        if residual is not None:
+            ops.fused_add_rms_norm(
+                x,
+                residual,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return x, residual
+        out = torch.empty_like(x)
+        ops.rms_norm(
+            out,
+            x,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return out
+
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
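For reference, the math behind rms_norm and fused_add_rms_norm is compact enough to state directly; a plain-PyTorch sketch of the unfused equivalent (eps value assumed here, not taken from a specific model):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def fused_add_rms_norm_reference(x, residual, weight, eps=1e-6):
    # The fused kernel adds the residual first and returns both the
    # normalized activation and the updated residual.
    residual = residual + x
    return rms_norm_reference(residual, weight, eps), residual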
@ -221,6 +221,29 @@ class RotaryEmbedding(CustomOp):
                              self.cos_sin_cache, self.is_neox_style)
         return query, key

+    def forward_xpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+                                                   dtype=query.dtype)
+        # ops.rotary_embedding()/batched_rotary_embedding()
+        # are in-place operations that update the query and key tensors.
+        if offsets is not None:
+            ops.batched_rotary_embedding(positions, query, key, self.head_size,
+                                         self.cos_sin_cache,
+                                         self.is_neox_style, self.rotary_dim,
+                                         offsets)
+        else:
+            ops.rotary_embedding(positions, query, key, self.head_size,
+                                 self.cos_sin_cache, self.is_neox_style)
+        return query, key
+
     def forward_tpu(
         self,
         positions: torch.Tensor,
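As a reminder of what these in-place kernels compute, here is an unbatched sketch of the neox-style rotation applied to the rotary slice of a query or key vector (a reference, not the IPEX implementation):

import torch

def apply_neox_rope(x: torch.Tensor, cos: torch.Tensor,
                    sin: torch.Tensor) -> torch.Tensor:
    # x: (..., rotary_dim); cos/sin broadcast over x's two halves.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # Rotate each (x1, x2) pair by the position-dependent angle.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)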
@ -307,7 +307,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = F.embedding(masked_input, self.weight)
+        output_parallel = F.embedding(masked_input.long(), self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
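The only change in this hunk is an explicit cast of the index tensor: the embedding lookup wants integer indices, and casting to int64 up front keeps it valid whatever integer dtype the token IDs arrive in. A minimal illustration:

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
token_ids = torch.tensor([1, 3, 5], dtype=torch.int32)
# Casting to long makes the lookup dtype-safe across backends.
embeddings = F.embedding(token_ids.long(), weight)
print(embeddings.shape)  # torch.Size([3, 4])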
@ -160,6 +160,26 @@ def is_tpu() -> bool:
     return libtpu is not None


+@lru_cache(maxsize=None)
+def is_xpu() -> bool:
+    from importlib.metadata import version
+    is_xpu_flag = "xpu" in version("vllm")
+    # vllm is not build with xpu
+    if not is_xpu_flag:
+        return False
+    try:
+        import intel_extension_for_pytorch as ipex  # noqa: F401
+        _import_ipex = True
+    except ImportError as e:
+        logger.warning("Import Error for IPEX: %s", e.msg)
+        _import_ipex = False
+    # ipex dependency is not ready
+    if not _import_ipex:
+        logger.warning("not found ipex lib")
+        return False
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
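A typical call site for the helper, sketched under the assumption that an XPU build (IPEX installed, device visible) is present; otherwise it simply falls back to CUDA:

import torch
from vllm.utils import is_xpu

device = torch.device("xpu" if is_xpu() else "cuda")
x = torch.ones(4, device=device)
print(x.device)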
@ -482,6 +502,9 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                            "This may slow down the performance.")
         return False
+    elif is_xpu():
+        print_warning_once("Pin memory is not supported on XPU.")
+        return False
     elif is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
@ -497,8 +520,12 @@ class CudaMemoryProfiler:

     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
-        torch.cuda.reset_peak_memory_stats(self.device)
-        mem = torch.cuda.max_memory_allocated(self.device)
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(self.device)
+            mem = torch.cuda.max_memory_allocated(self.device)
+        elif is_xpu():
+            torch.xpu.reset_peak_memory_stats(self.device)
+            mem = torch.xpu.max_memory_allocated(self.device)
         return mem

     def __enter__(self):
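The profiler is used as a context manager around model loading elsewhere in this commit; a small sketch of that pattern, assuming a CUDA or XPU device is present (with neither, there is nothing for it to read):

import torch
from vllm.utils import CudaMemoryProfiler

dev = "cuda" if torch.cuda.is_available() else "xpu"
with CudaMemoryProfiler() as m:
    # Allocation-heavy work goes here; on XPU the context manager now reads
    # torch.xpu peak-memory statistics instead of torch.cuda.
    weights = torch.empty(1024, 1024, device=dev)
print(f"consumed {m.consumed_memory / 2**30:.4f} GB")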
@ -4,7 +4,7 @@ from typing import List
 import torch

 from vllm.attention import get_attn_backend
-from vllm.config import CacheConfig, ModelConfig, ParallelConfig
+from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
                         is_pin_memory_available)
@ -25,10 +25,12 @@ class CacheEngine:
         cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
+        device_config: DeviceConfig,
     ) -> None:
         self.cache_config = cache_config
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.device_config = device_config

         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
@ -55,7 +57,8 @@ class CacheEngine:
         )

         # Initialize the cache.
-        self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
+        self.gpu_cache = self._allocate_kv_cache(
+            self.num_gpu_blocks, self.device_config.device_type)
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

     def _allocate_kv_cache(
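The effect of threading device_config through CacheEngine is that the "GPU" half of the KV cache is allocated on whatever device type the config names instead of a hard-coded "cuda". Conceptually (a simplified sketch, not the actual _allocate_kv_cache signature):

import torch

def allocate_kv_cache(num_blocks, block_shape, dtype, device_type: str):
    # device_type is "cuda" on NVIDIA/AMD and "xpu" on Intel GPUs; the swap
    # cache still uses "cpu".
    return [
        torch.empty((num_blocks, *block_shape), dtype=dtype, device=device_type)
        for _ in range(2)  # one tensor each for keys and values
    ]

kv = allocate_kv_cache(8, (16, 16, 64), torch.float16, "cpu")
print(kv[0].shape)  # torch.Size([8, 16, 16, 64])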
@ -205,7 +205,8 @@ class Worker(WorkerBase):
     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config)
+                                        self.parallel_config,
+                                        self.device_config)
         self.gpu_cache = self.cache_engine.gpu_cache

     def _warm_up_model(self) -> None:
417
vllm/worker/xpu_model_runner.py
Normal file
@ -0,0 +1,417 @@
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata

logger = init_logger(__name__)

_PAD_SLOT_ID = -1
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


class XPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.cache_config = cache_config
        self.vision_language_config = vision_language_config
        self.is_driver_worker = is_driver_worker

        self.sliding_window = model_config.get_sliding_window()
        self.device_config = device_config
        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.block_size = cache_config.block_size
        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture
            if self.model_config is not None else 0)

        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM blocker manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            seq_data = SequenceData([0] * seq_len)
            dummy_multi_modal_data = None
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.xpu.synchronize()
        return

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[torch.Tensor]]:
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts
            # or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # subquery_lens is not needed if chunked prefill is not
                # supported. Since CPU worker doesn't support chunked prefill
                # just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=None,
            max_seqlen=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)
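Both _prepare_prompt and _prepare_decode map a token position to a physical KV-cache slot through the sequence's block table with the same two-line arithmetic; a tiny worked example with made-up block numbers:

# block_size = 4 and a block table that maps logical block 0 -> physical 7,
# logical block 1 -> physical 2 (toy numbers).
block_size = 4
block_table = [7, 2]

def slot_for(position: int) -> int:
    block_number = block_table[position // block_size]
    block_offset = position % block_size
    return block_number * block_size + block_offset

print([slot_for(p) for p in range(6)])  # [28, 29, 30, 31, 8, 9]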
193
vllm/worker/xpu_worker.py
Normal file
@ -0,0 +1,193 @@
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(LoraNotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        assert device_config.device_type == "xpu"
        assert is_xpu()

        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        self.model_runner = XPUModelRunner(  # type: ignore
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        self.gpu_cache: List[torch.Tensor]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and is_xpu():
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # keep this method for `empty_cache` and `synchronize` api
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculate the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX don't support capture graph yet
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # use sockets as default Level zero IPC exchange backend. By
            # default oneccl will use `drmfd` as mechanism which need extra
            # dependency (libdrm and drm headers) on your system.
            ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                "sockets")
            os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
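The block-count computation at the end of determine_num_available_blocks is plain arithmetic over the profiled numbers; a worked example with made-up sizes:

total_gpu_memory = 16 * 2**30          # 16 GiB device
gpu_memory_utilization = 0.9
peak_memory = 6 * 2**30                # measured during the dummy forward pass
cache_block_size = 2 * 2**20           # bytes per KV block for this model
swap_space_bytes = 4 * 2**30

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization -
                      peak_memory) // cache_block_size)
num_cpu_blocks = int(swap_space_bytes // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)  # 4300 2048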