[Hardware] Initial TPU integration (#5292)
This commit is contained in:
parent 847cdcca1c
commit 1a8bfd92d5
19
Dockerfile.tpu
Normal file
@ -0,0 +1,19 @@
ARG NIGHTLY_DATE="20240601"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE

WORKDIR /workspace
COPY . /workspace/vllm

ENV VLLM_TARGET_DEVICE="tpu"
# Install aiohttp separately to avoid build errors.
RUN pip install aiohttp
# Install the TPU and Pallas dependencies.
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Build vLLM.
RUN cd /workspace/vllm && python setup.py develop

CMD ["/bin/bash"]
@ -189,7 +189,7 @@ if __name__ == '__main__':
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA, CPU, and TPU.')
parser.add_argument('--block-size',
type=int,
@ -346,7 +346,7 @@ if __name__ == "__main__":
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA, CPU, and TPU.')
parser.add_argument(
"--enable-prefix-caching",
75
docs/source/getting_started/tpu-installation.rst
Normal file
@ -0,0 +1,75 @@
.. _installation_tpu:

Installation with TPU
=====================

vLLM supports Google Cloud TPUs using PyTorch XLA.

Requirements
------------

* Google Cloud TPU VM (single host)
* TPU versions: v5e, v5p, v4
* Python: 3.10

Installation options:

1. :ref:`Build a docker image with Dockerfile <build_docker_tpu>`.
2. :ref:`Build from source <build_from_source_tpu>`.

.. _build_docker_tpu:

Build a docker image with :code:`Dockerfile.tpu`
------------------------------------------------

`Dockerfile.tpu <https://github.com/vllm-project/vllm/blob/main/Dockerfile.tpu>`_ is provided to build a docker image with TPU support.

.. code-block:: console

    $ docker build -f Dockerfile.tpu -t vllm-tpu .

You can run the docker image with the following command:

.. code-block:: console

    $ # Make sure to add `--privileged --net host --shm-size=16G`.
    $ docker run --privileged --net host --shm-size=16G -it vllm-tpu

.. _build_from_source_tpu:

Build from source
-----------------

You can also build and install the TPU backend from source.

First, install the dependencies:

.. code-block:: console

    $ # (Recommended) Create a new conda environment.
    $ conda create -n myenv python=3.10 -y
    $ conda activate myenv

    $ # Clean up the existing torch and torch-xla packages.
    $ pip uninstall torch torch-xla -y

    $ # Install PyTorch and PyTorch XLA.
    $ export DATE="+20240601"
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
    $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

    $ # Install JAX and Pallas.
    $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
    $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

    $ # Install other build dependencies.
    $ pip install packaging aiohttp

Next, build vLLM from source. This will only take a few seconds:

.. code-block:: console

    $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
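After the build, you can sanity-check the TPU backend with a short offline generation. The snippet below is an illustrative sketch rather than part of this change: the model name is a placeholder (note that the Pallas attention kernel requires the head size to be a multiple of 128), and the device type is auto-detected on a TPU VM, so no extra flag is needed.

.. code-block:: python

    from vllm import LLM, SamplingParams

    # The device type is auto-detected as "tpu" on a TPU VM.
    llm = LLM(model="meta-llama/Llama-2-7b-hf")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, max_tokens=16))
    print(outputs[0].outputs[0].text)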
@ -63,8 +63,9 @@ Documentation
getting_started/installation
getting_started/amd-installation
getting_started/neuron-installation
getting_started/cpu-installation
getting_started/neuron-installation
getting_started/tpu-installation
getting_started/quickstart
getting_started/debugging
getting_started/examples/examples_index
7
requirements-tpu.txt
Normal file
@ -0,0 +1,7 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies using the commands in Dockerfile.tpu.
triton # To avoid import errors
22
setup.py
@ -206,9 +206,9 @@ class cmake_build_ext(build_ext):


def _is_cuda() -> bool:
return VLLM_TARGET_DEVICE == "cuda" \
and torch.version.cuda is not None \
and not _is_neuron()
has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not (_is_neuron() or _is_tpu()))


def _is_hip() -> bool:
@ -225,10 +225,18 @@ def _is_neuron() -> bool:
return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"


def _is_tpu() -> bool:
return VLLM_TARGET_DEVICE == "tpu"


def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"


def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()


def _install_punica() -> bool:
return envs.VLLM_INSTALL_PUNICA_KERNELS

@ -325,6 +333,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_tpu():
version += "+tpu"
elif _is_cpu():
version += "+cpu"
else:
@ -372,6 +382,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu():
requirements = _read_requirements("requirements-cpu.txt")
else:
@ -385,7 +397,7 @@ ext_modules = []
if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C"))

if not _is_neuron():
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))

if _install_punica():
@ -428,6 +440,6 @@ setup(
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data,
)
232
vllm/attention/backends/pallas.py
Normal file
@ -0,0 +1,232 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
import torch_xla.experimental.dynamo_set_buffer_donor
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def make_metadata(*args, **kwargs) -> "PallasMetadata":
|
||||
return PallasMetadata(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
raise NotImplementedError("swap_blocks is not implemented.")
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: Dict[int, List[int]],
|
||||
) -> None:
|
||||
# TODO(woosuk): Implement this.
|
||||
raise NotImplementedError("copy_blocks is not implemented.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata(AttentionMetadata):
|
||||
|
||||
# Currently, input sequences can only contain all prefills
|
||||
# or all decoding.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
context_lens: Optional[torch.Tensor]
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
assert self.num_decode_tokens == 0
|
||||
assert self.block_tables is None
|
||||
assert self.context_lens is None
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.block_tables is not None
|
||||
assert self.context_lens is not None
|
||||
return self
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError("Head size must be a multiple of 128.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError("Sliding window is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if torch_xla.tpu.version() < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
|
||||
self.megacore_mode = None
|
||||
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
|
||||
if not tpu_type.endswith("lite"):
|
||||
if self.num_kv_heads % 2 == 0:
|
||||
self.megacore_mode = "kv_head"
|
||||
else:
|
||||
# NOTE(woosuk): If the batch size is not a multiple of 2, the
|
||||
# megacore mode will be None.
|
||||
self.megacore_mode = "batch"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||
attn_metadata: PallasMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
assert kv_scale == 1.0
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||
self.head_size)
|
||||
|
||||
if kv_cache[0] is not None:
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
key_cache, value_cache = kv_cache
|
||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
query = query * self.scale
|
||||
if attn_metadata.num_prefills > 0:
|
||||
assert seq_len % 16 == 0, (
|
||||
"Pallas FlashAttention kernel requires seq_len to be a "
|
||||
f"multiple of 16 but got {seq_len}")
|
||||
|
||||
# Handle GQA/MQA.
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
|
||||
key = key.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=-2)
|
||||
value = value.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
|
||||
# while the input is [batch_size, seq_len, num_heads, d_model].
|
||||
# Permute the input to match the required format.
|
||||
output = torch.ops.xla.flash_attention(
|
||||
query.permute(0, 2, 1, 3),
|
||||
key.permute(0, 2, 1, 3),
|
||||
value.permute(0, 2, 1, 3),
|
||||
True,
|
||||
)
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Decoding run.
|
||||
assert kv_cache is not None
|
||||
|
||||
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
|
||||
if self.megacore_mode == "batch" and batch_size % 2 != 0:
|
||||
megacore_mode = None
|
||||
else:
|
||||
megacore_mode = self.megacore_mode
|
||||
|
||||
# NOTE(woosuk): A temporary workaround to avoid the error:
|
||||
# "xla::paged_attention() Expected a value of type 'str' for
|
||||
# argument 'megacore_mode' but instead found type 'NoneType'."
|
||||
if megacore_mode is not None:
|
||||
output = torch.ops.xla.paged_attention(
|
||||
query.squeeze(dim=1),
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
pages_per_compute_block,
|
||||
megacore_mode=megacore_mode,
|
||||
)
|
||||
else:
|
||||
output = torch.ops.xla.paged_attention(
|
||||
query.squeeze(dim=1),
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
pages_per_compute_block,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.reshape(batch_size, seq_len, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
||||
|
||||
key = key.flatten(0, 2)
|
||||
value = value.flatten(0, 2)
|
||||
key_cache = key_cache.flatten(0, 2)
|
||||
value_cache = value_cache.flatten(0, 2)
|
||||
key_cache.index_copy_(0, slot_mapping, key)
|
||||
value_cache.index_copy_(0, slot_mapping, value)
|
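# Illustrative check (not part of the diff): the Pallas backend stores the KV
# cache with the head dimension first, which write_to_kv_cache() relies on when
# it flattens the first three dimensions for index_copy_.
from vllm.attention.backends.pallas import PallasAttentionBackend

shape = PallasAttentionBackend.get_kv_cache_shape(num_blocks=512,
                                                  block_size=16,
                                                  num_kv_heads=8,
                                                  head_size=128)
assert shape == (8, 512, 16, 128)  # [num_kv_heads, num_blocks, block_size, head_size]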
@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_cpu, is_hip, is_tpu

logger = init_logger(__name__)

@ -18,6 +18,7 @@ class _Backend(enum.Enum):
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
PALLAS = enum.auto()


@lru_cache(maxsize=None)
@ -66,6 +67,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else:
raise ValueError("Invalid attention backend.")

@ -80,7 +85,6 @@ def which_attn_to_use(
block_size: int,
) -> _Backend:
"""Returns which flash attention backend to use."""

# Default case.
selected_backend = _Backend.FLASH_ATTN

@ -100,6 +104,11 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS

if is_hip():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
@ -11,7 +11,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@ -748,6 +748,8 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu():
self.device_type = "cpu"
else:
@ -761,6 +763,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else:
# Set device with device type
self.device = torch.device(self.device_type)
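# Sketch of the auto-detection added above (illustrative, not part of the diff):
# on a TPU VM, where `libtpu` is importable, the default "auto" device now
# resolves to the TPU backend.
from vllm.config import DeviceConfig

device_config = DeviceConfig(device="auto")
print(device_config.device_type)  # "tpu" on a TPU host, "cuda"/"cpu" elsewhere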
@ -504,7 +504,7 @@ class EngineArgs:
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"],
choices=["auto", "cuda", "neuron", "cpu", "tpu"],
help='Device type for vLLM execution.')

# Related to Vision-language models such as llava
@ -375,6 +375,9 @@ class AsyncLLMEngine:
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
@ -341,6 +341,9 @@ class LLMEngine:
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),

# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
}

# end-env-vars-definition
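# Usage sketch for the new environment variable (the path below is illustrative):
# point the XLA persistent compilation cache at a persistent disk so warmed-up
# graphs survive restarts. Set it before vLLM initializes the TPU worker.
import os

os.environ["VLLM_XLA_CACHE_PATH"] = "/mnt/persist/vllm_xla_cache"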
101
vllm/executor/tpu_executor.py
Normal file
@ -0,0 +1,101 @@
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUExecutor(ExecutorBase):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert not self.scheduler_config.chunked_prefill_enabled, (
|
||||
"Chunked prefill is not yet supported for TPU backend")
|
||||
assert not self.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
if self.model_config.dtype in (torch.float16, torch.float32):
|
||||
logger.warning(
|
||||
"The TPU backend currently does not support %s. "
|
||||
"Using bfloat16 instead.", self.model_config.dtype)
|
||||
self.model_config.dtype = torch.bfloat16
|
||||
|
||||
# Instantiate the worker and load the model to the device.
|
||||
self._init_worker()
|
||||
|
||||
def _init_worker(self):
|
||||
from vllm.worker.tpu_worker import TPUWorker
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"TPUExecutor currently only supports a single TPU chip.")
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = TPUWorker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
self.cache_config,
|
||||
self.load_config,
|
||||
self.vision_language_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
)
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker."""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
# with other executors. We could log in the engine level, but work
|
||||
# remains to abstract away the device for non-GPU configurations.
|
||||
logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
return self.driver_worker.determine_num_available_blocks()
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError("LoRA is not implemented for TPU backend.")
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError("LoRA is not implemented for TPU backend.")
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError("LoRA is not implemented for TPU backend.")
|
||||
|
||||
def check_health(self) -> None:
|
||||
# TPUExecutor will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
|
||||
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> SamplerOutput:
|
||||
output = await make_async(self.driver_worker.execute_model
|
||||
)(execute_model_req)
|
||||
return output
|
@ -1,6 +1,6 @@
import torch.nn as nn

from vllm.utils import is_cpu, is_hip
from vllm.utils import is_cpu, is_hip, is_tpu


class CustomOp(nn.Module):
@ -56,5 +56,7 @@ class CustomOp(nn.Module):
return self.forward_hip
elif is_cpu():
return self.forward_cpu
elif is_tpu():
return self.forward_tpu
else:
return self.forward_cuda
@ -28,6 +28,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils import is_tpu
|
||||
|
||||
|
||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x_ = torch.view_as_complex(
|
||||
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
|
||||
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
|
||||
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
|
||||
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
||||
-1).transpose(1, 2)
|
||||
return x_out
|
||||
|
||||
|
||||
class RotaryEmbedding(CustomOp):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.use_native2 = is_tpu() and is_neox_style
|
||||
if not self.use_native2:
|
||||
cache = cache.to(dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
else:
|
||||
cos, sin = cache.chunk(2, dim=-1)
|
||||
freqs_cis = cos + 1j * sin
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
"""A PyTorch-native implementation equivalent to forward().
|
||||
|
||||
This method mimics the implementation of the custom CUDA kernel
|
||||
used in `forward_cuda()`.
|
||||
"""
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||
|
||||
@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
|
||||
key = key.flatten(-2)
|
||||
return query, key
|
||||
|
||||
def forward_native2(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Another PyTorch-native implementation of forward().
|
||||
|
||||
This method might perform better than `forward_native()` when compiled.
|
||||
"""
|
||||
if positions.dim() == 1:
|
||||
batch_size = 1
|
||||
seq_len = positions.shape[0]
|
||||
else:
|
||||
batch_size, seq_len = positions.shape
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
|
||||
freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(batch_size, seq_len, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = _apply_rotary_emb(query_rot, freqs_cis)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(batch_size, seq_len, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb(key_rot, freqs_cis)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
forward_fn = (self.forward_native2
|
||||
if self.use_native2 else self.forward_native)
|
||||
return forward_fn(positions, query, key, offsets)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
|
||||
s += f", max_position_embeddings={self.max_position_embeddings}"
|
||||
|
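# Sanity check (illustrative, not part of the diff): the complex product used in
# _apply_rotary_emb() is exactly the NeoX-style rotation,
#   (x1 + i*x2) * (cos + i*sin) = (x1*cos - x2*sin) + i*(x1*sin + x2*cos).
import torch

x1, x2, theta = torch.randn(4), torch.randn(4), torch.randn(4)
z = torch.complex(x1, x2) * torch.complex(torch.cos(theta), torch.sin(theta))
assert torch.allclose(z.real, x1 * torch.cos(theta) - x2 * torch.sin(theta), atol=1e-6)
assert torch.allclose(z.imag, x1 * torch.sin(theta) + x2 * torch.cos(theta), atol=1e-6)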
@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu

logger = init_logger(__name__)

@ -227,12 +228,26 @@ class DefaultModelLoader(BaseModelLoader):
if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only supports *.bin checkpoints
assert use_safetensors is False
return np_cache_weights_iterator(model_name_or_path,
self.load_config.download_dir,
hf_folder, hf_weights_files)
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
return pt_weights_iterator(hf_weights_files)
weights_iterator = np_cache_weights_iterator(
model_name_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weights_iterator = pt_weights_iterator(hf_weights_files)

if is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# the XLA program does not accumulate too many ops.
import torch_xla.core.xla_model as xm

def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()

weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
@ -146,6 +146,15 @@ def is_neuron() -> bool:
return transformers_neuronx is not None


@lru_cache(maxsize=None)
def is_tpu() -> bool:
try:
import libtpu
except ImportError:
libtpu = None
return libtpu is not None


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor,
return tensor


def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()


def merge_dicts(dict1: Dict[Any, List[Any]],
dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
"""Merge 2 dicts that have key -> List of items.
@ -6,7 +6,8 @@ import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
is_pin_memory_available)

logger = init_logger(__name__)

@ -108,9 +109,5 @@ class CacheEngine:
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = _get_dtype_size(dtype)
dtype_size = get_dtype_size(dtype)
return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
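# Worked example (illustrative numbers, not from this change): the shared
# get_dtype_size() helper sizes one KV cache block the same way in
# CacheEngine.get_cache_block_size and TPUWorker.get_cache_block_size_bytes.
import torch

from vllm.utils import get_dtype_size

block_size, num_kv_heads, head_size, num_layers = 16, 32, 128, 32
key_cache_block = block_size * num_kv_heads * head_size
total_elements = num_layers * 2 * key_cache_block  # key + value
print(total_elements * get_dtype_size(torch.bfloat16))  # bytes per KV cache block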
525
vllm/worker/tpu_model_runner.py
Normal file
@ -0,0 +1,525 @@
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceOutput)
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PAD_SLOT_ID = 0 # FIXME(woosuk)
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
vision_language_config: Optional[VisionLanguageConfig] = None,
|
||||
):
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.load_config = load_config
|
||||
self.vision_language_config = vision_language_config
|
||||
|
||||
self.block_size = self.cache_config.block_size
|
||||
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
|
||||
self.block_size)
|
||||
self.block_tables = np.zeros(
|
||||
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
|
||||
dtype=np.int32)
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
False,
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.device = self.device_config.device
|
||||
|
||||
model = get_model(
|
||||
model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
device_config=self.device_config,
|
||||
parallel_config=self.parallel_config,
|
||||
cache_config=self.cache_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
lora_config=None,
|
||||
)
|
||||
xm.wait_device_ops()
|
||||
|
||||
model = ModelWrapper(model)
|
||||
self.model = torch.compile(model, backend="openxla", fullgraph=True)
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
is_prompt: bool,
|
||||
) -> None:
|
||||
if is_prompt:
|
||||
seq_len = (seq_len + 15) // 16 * 16
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=batch_size * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=None,
|
||||
context_lens=None,
|
||||
)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
else:
|
||||
assert seq_len == 1
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, self.max_num_blocks_per_seq),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
context_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size * seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
|
||||
# Dummy run.
|
||||
self.model(token_ids, position_ids, kv_caches, attn_metadata,
|
||||
input_lens, t, p)
|
||||
|
||||
def warmup_model(
|
||||
self,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> None:
|
||||
# Prefill
|
||||
logger.info("Compiling the model with different input shapes...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while True:
|
||||
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if seq_len >= self.model_config.max_model_len:
|
||||
break
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for prefill done in %.2f s.", end - start)
|
||||
|
||||
# Decode
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 1
|
||||
while True:
|
||||
self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for decode done in %.2f s.", end - start)
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
):
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
prompt_lens: List[int] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
# Could include output tokens when a request is preempted.
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
input_tokens.append(prompt_tokens)
|
||||
input_positions.append(list(range(prompt_len)))
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
slot_mapping.append([])
|
||||
for i in range(prompt_len):
|
||||
block_number = block_table[i // self.block_size]
|
||||
block_offset = i % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping[-1].append(slot)
|
||||
|
||||
assert len(prompt_lens) > 0
|
||||
num_prefills = len(prompt_lens)
|
||||
num_prefill_tokens = sum(prompt_lens)
|
||||
|
||||
# Add paddings to make the shape [batch_size, max_prompt_len] where
|
||||
# max_prompt_len is smallest power of 2 that is greater than or equal
|
||||
# to the maximum prompt length.
|
||||
# We need the 2D input shape because the Pallas FlashAttention kernel
|
||||
# does not support packed 1D inputs.
|
||||
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
|
||||
max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
|
||||
input_tokens = make_tensor_with_pad(input_tokens,
|
||||
max_prompt_len,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_positions = make_tensor_with_pad(input_positions,
|
||||
max_prompt_len,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = make_tensor_with_pad(slot_mapping,
|
||||
max_prompt_len,
|
||||
pad=_PAD_SLOT_ID,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
prompt_lens = torch.tensor(prompt_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=None,
|
||||
context_lens=None,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, prompt_lens
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
):
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
context_lens: List[int] = []
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
batch_size = _get_padded_batch_size(num_seq_groups)
|
||||
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
assert not seq_group_metadata.is_prompt
|
||||
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
context_lens.append(seq_len)
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
self.block_tables[i, :len(block_table)] = block_table
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append([slot])
|
||||
|
||||
num_paddings = batch_size - num_seq_groups
|
||||
input_tokens = input_tokens + [[0]] * num_paddings
|
||||
input_positions = input_positions + [[0]] * num_paddings
|
||||
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||
context_lens = context_lens + [0] * num_paddings
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
block_tables = torch.tensor(self.block_tables[:batch_size],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_lens = torch.tensor([1] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=slot_mapping,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, input_lens
|
||||
|
||||
def _prepare_sample(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
padded_batch_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
t = []
|
||||
p = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert seq_group_metadata.sampling_params is not None
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
|
||||
t.append(sampling_params.temperature
|
||||
if sampling_params.temperature >= 1e-5 else 1e-5)
|
||||
p.append(sampling_params.top_p)
|
||||
num_paddings = padded_batch_size - len(seq_group_metadata_list)
|
||||
t += [1.0] * num_paddings
|
||||
p += [1.0] * num_paddings
|
||||
|
||||
t = torch.tensor(t, dtype=torch.float32, device=self.device)
|
||||
p = torch.tensor(p, dtype=torch.float32, device=self.device)
|
||||
return t, p
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
):
|
||||
assert seq_group_metadata_list is not None
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
if seq_group_metadata_list[0].is_prompt:
|
||||
inputs = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
inputs = self._prepare_decode(seq_group_metadata_list)
|
||||
padded_batch_size = inputs[0].shape[0]
|
||||
sample_inputs = self._prepare_sample(seq_group_metadata_list,
|
||||
padded_batch_size)
|
||||
return inputs + sample_inputs
|
||||
|
||||
def _execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> List[CompletionSequenceGroupOutput]:
|
||||
inputs = self.prepare_inputs(seq_group_metadata_list)
|
||||
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
|
||||
*inputs[2:])
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
|
||||
i = 0
|
||||
sampler_outputs = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_outputs = []
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
for seq_id in seq_ids:
|
||||
next_token_id = next_token_ids[i]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: Logprob(0.0)}))
|
||||
i += 1
|
||||
sampler_outputs.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||
return sampler_outputs
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> SamplerOutput:
|
||||
assert seq_group_metadata_list is not None
|
||||
if seq_group_metadata_list[0].is_prompt:
|
||||
# NOTE(woosuk): To reduce the compilation time, we only compile the
|
||||
# prefill inputs with batch size 1. Because the scheduler is not
|
||||
# aware of this limitation, we need to handle batch size > 1
|
||||
# internally by calling the model multiple times and concatenating
|
||||
# the outputs.
|
||||
# FIXME(woosuk): This is a temporary hack to not change the existing
|
||||
# scheduler. We need to fix this in the future.
|
||||
sampler_outputs = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
sampler_outputs += self._execute_model([seq_group_metadata],
|
||||
kv_caches)
|
||||
else:
|
||||
sampler_outputs = self._execute_model(seq_group_metadata_list,
|
||||
kv_caches)
|
||||
return SamplerOutput(sampler_outputs)
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model.eval()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_lens: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
Args:
|
||||
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||
kv_caches: The key and value caches. They can be None during the
|
||||
memory profiling at initialization.
|
||||
attn_metadata: The Pallas attention metadata.
|
||||
input_lens: The actual input lengths of shape [batch_size].
|
||||
t: The sampling temperature of shape [batch_size].
|
||||
p: The top-p probability of shape [batch_size].
|
||||
"""
|
||||
batch_size, seq_len = token_ids.shape
|
||||
# Calculate the positions to sample from.
|
||||
base_indices = torch.arange(
|
||||
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
|
||||
logits_indices = base_indices + input_lens - 1
|
||||
|
||||
# FIXME(woosuk): This is a temporary hack to avoid using the existing
|
||||
# sampler and sampling metadata.
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=[],
|
||||
selected_token_indices=logits_indices,
|
||||
categorized_sample_indices={},
|
||||
num_prompts=attn_metadata.num_prefills,
|
||||
)
|
||||
|
||||
# Skip this in memory profiling at initialization.
|
||||
if kv_caches[0][0] is not None:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
# work, we need to flatten the first three dimensions and modify
|
||||
# the slot_mapping accordingly.
|
||||
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
head_indices = torch.arange(0,
|
||||
num_kv_heads,
|
||||
device=slot_mapping.device,
|
||||
dtype=slot_mapping.dtype)
|
||||
head_indices *= block_size * num_blocks
|
||||
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
|
||||
-1, num_kv_heads)
|
||||
slot_mapping = slot_mapping + head_indices.view(1, -1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
hidden_states = self.model(
|
||||
token_ids,
|
||||
position_ids,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
logits = logits / t.unsqueeze(dim=1)
|
||||
# FIXME(woosuk): Disabled top-p sampling since it's too slow.
|
||||
# logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
# FIXME(woosuk): best_of > 1 is not supported.
|
||||
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
def _get_padded_prefill_len(x: int) -> int:
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||
# multiple of 16. This is also good for performance.
|
||||
if x <= 16:
|
||||
return 16
|
||||
return 1 << (x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_padded_batch_size(batch_size: int) -> int:
|
||||
if batch_size <= 2:
|
||||
return batch_size
|
||||
elif batch_size <= 4:
|
||||
return 4
|
||||
elif batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 15) // 16) * 16
|
||||
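# Worked examples for the padding helpers above (illustrative, not part of the
# diff): prompts are padded to at least 16 tokens and then to the next power of
# two; decode batches go to small fixed buckets and then to multiples of 16.
assert _get_padded_prefill_len(1) == 16
assert _get_padded_prefill_len(17) == 32
assert _get_padded_batch_size(3) == 4
assert _get_padded_batch_size(25) == 32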
|
||||
|
||||
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
||||
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
|
||||
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
|
||||
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
|
||||
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
||||
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
||||
return logits
|
198
vllm/worker/tpu_worker.py
Normal file
@ -0,0 +1,198 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
|
||||
from vllm.worker.tpu_model_runner import TPUModelRunner
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.cache_config = cache_config
|
||||
self.load_config = load_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
assert self.device_config.device_type == "tpu"
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
self.model_runner = TPUModelRunner(model_config, parallel_config,
|
||||
scheduler_config, device_config,
|
||||
cache_config, load_config,
|
||||
vision_language_config)
|
||||
|
||||
def init_device(self) -> None:
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# NOTE(woosuk): This is just a hack to initialize the TP group.
|
||||
# This cannot perform the actual communication ops.
|
||||
init_distributed_environment(
|
||||
world_size=self.parallel_config.world_size,
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
distributed_init_method=self.distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): This does not completely eliminate the recompilation
|
||||
# overhead because dynamo does not cache the compiled results.
|
||||
xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH),
|
||||
readonly=False)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
|
||||
kv_caches = [(None, None) for _ in range(num_layers)]
|
||||
self.model_runner._dummy_run(
|
||||
batch_size=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
kv_caches=kv_caches,
|
||||
is_prompt=True,
|
||||
)
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
m = xm.get_memory_info(self.device)
|
||||
program_size = 1024 * 1024 * 1024 # 1GB
|
||||
free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
|
||||
kv_cache_bytes = int(free_bytes *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
kv_cache_dtype_bytes = get_dtype_size(self.cache_dtype)
|
||||
block_size = self.cache_config.block_size
|
||||
num_tpu_blocks = (kv_cache_bytes //
|
||||
(kv_cache_dtype_bytes * block_size * num_layers * 2 *
|
||||
head_size * num_kv_heads))
|
||||
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
||||
return num_tpu_blocks, 0
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
dtype = self.cache_dtype
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
|
||||
self.tpu_cache = []
|
||||
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.zeros(tpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
value_cache = torch.zeros_like(key_cache)
|
||||
self.tpu_cache.append((key_cache, value_cache))
|
||||
self._warmup_model()
|
||||
|
||||
def _warmup_model(self) -> None:
|
||||
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
|
||||
# for CUDA graphs. We should refactor this part.
|
||||
if not self.model_config.enforce_eager:
|
||||
# Warm up the model with all possible input shapes so that
|
||||
# compilation never happens during the actual execution.
|
||||
# This may take ~30 mins for the first run and ~20 mins for the
|
||||
# subsequent runs.
|
||||
# If `enforce_eager` is True, the ahead-of-time compilation is
|
||||
# skipped and the compilation happens during the actual execution,
|
||||
# which is bad for performance but useful for development.
|
||||
self.model_runner.warmup_model(self.tpu_cache)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
key_cache_block = self.cache_config.block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = get_dtype_size(self.cache_dtype)
|
||||
return dtype_size * total
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
if execute_model_req is None:
|
||||
return []
|
||||
|
||||
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
if num_seq_groups == 0:
|
||||
return []
|
||||
|
||||
# Currently, TPUWorker does not support swapping.
|
||||
# TODO(woosuk): Support block copying.
|
||||
assert len(execute_model_req.blocks_to_swap_in) == 0, (
|
||||
"Swapping is not supported for the TPU backend.")
|
||||
assert len(execute_model_req.blocks_to_swap_out) == 0, (
|
||||
"Swapping is not supported for the TPU backend.")
|
||||
assert len(execute_model_req.blocks_to_copy) == 0
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.tpu_cache)
|
||||
return [output]
|