vllm/cacheflow/worker/cache_engine.py


"""CacheEngine class for managing the KV cache."""
2023-02-13 09:35:48 +00:00
from typing import Dict, List, Tuple
2023-02-09 11:28:02 +00:00
import torch
from cacheflow import cache_ops
2023-05-20 13:06:59 -07:00
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
2023-02-09 11:28:02 +00:00
KVCache = Tuple[torch.Tensor, torch.Tensor]


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_heads(parallel_config)
        self.dtype = model_config.dtype

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations.
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
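        # NOTE: Using a side stream (rather than the default stream) lets the
        # swap transfers below overlap with model execution.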
        # Initialize the events for stream synchronization.
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        x = 16 // element_size
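        # The key cache splits the head dimension into chunks of x elements
        # (16 bytes each) so the attention kernel can fetch keys with
        # vectorized 16-byte loads. With fp16 (2-byte elements), x = 8.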
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
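        # Unlike keys, values need no byte-level packing; the block keeps the
        # natural (num_heads, head_size, block_size) layout.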
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def allocate_gpu_cache(self) -> List[KVCache]:
        gpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
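        # Each layer gets its own (key, value) tensor pair; the leading
        # dimension indexes the physical blocks of the paged cache.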
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_gpu_blocks, *key_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            value_blocks = torch.empty(
                size=(self.num_gpu_blocks, *value_block_shape),
                dtype=self.dtype,
                device="cuda",
            )
            gpu_cache.append((key_blocks, value_blocks))
        return gpu_cache

    def allocate_cpu_cache(self) -> List[KVCache]:
        cpu_cache: List[KVCache] = []
        key_block_shape = self.get_key_block_shape()
        value_block_shape = self.get_value_block_shape()
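        # Pinned (page-locked) host memory allows asynchronous DMA transfers
        # when blocks are swapped to and from the GPU.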
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=(self.num_cpu_blocks, *key_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            value_blocks = torch.empty(
                size=(self.num_cpu_blocks, *value_block_shape),
                dtype=self.dtype,
                pin_memory=True,
            )
            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
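        # Issue all copies on the dedicated cache stream and record one event
        # per layer, so consumers can wait on individual layers instead of
        # synchronizing on the whole transfer.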
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
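        # src_to_dsts maps one source block to possibly several destination
        # blocks, e.g. when a shared block is forked via copy-on-write.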
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)
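        # Example (hypothetical config): block_size=16, num_heads=32,
        # head_size=128, num_layers=32, fp16 -> 16 * 32 * 128 = 65,536
        # elements per key block, doubled for values and summed over the 32
        # layers gives 4,194,304 elements, i.e. 8 MiB per cache block at
        # 2 bytes per element.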
        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = _get_dtype_size(model_config.dtype)
        return dtype_size * total


def _get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()