2023-05-14 22:32:38 -07:00
|
|
|
"""CacheEngine class for managing the KV cache."""
|
2023-02-13 09:35:48 +00:00
|
|
|
from typing import Dict, List, Tuple
|
2023-02-09 11:28:02 +00:00
|
|
|
|
|
|
|
import torch
|
2023-05-14 22:32:38 -07:00
|
|
|
|
2023-03-01 15:02:19 -08:00
|
|
|
from cacheflow import cache_ops
|
2023-05-20 13:06:59 -07:00
|
|
|
from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
|
2023-02-09 11:28:02 +00:00
|
|
|
|
|
|
|
# Per-layer KV cache: a (key_blocks, value_blocks) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
|
|
|
|
|
|
|
class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as swapping and copying.

    For every layer the engine holds one ``(key_blocks, value_blocks)`` pair
    on the GPU and one in pinned CPU memory. Swaps between the two run on a
    dedicated CUDA stream, and a per-layer CUDA event is recorded on that
    stream after each layer's swap (presumably so callers can wait on
    individual layers; the waiting side is not visible here).
    """

    def __init__(
        self,
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> None:
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_heads(parallel_config)
        self.dtype = model_config.dtype

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
        self.num_cpu_blocks = cache_config.num_cpu_blocks

        # Initialize the cache.
        self.gpu_cache = self.allocate_gpu_cache()
        self.cpu_cache = self.allocate_cpu_cache()

        # Initialize the stream for caching operations. The assert guarantees
        # swaps are issued on a stream distinct from the current one
        # (presumably so they can overlap with work on the current stream).
        self.cache_stream = torch.cuda.Stream()
        assert self.cache_stream != torch.cuda.current_stream()
        # Initialize the events for stream synchronization (one per layer).
        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]

    def get_key_block_shape(self) -> Tuple[int, int, int, int]:
        """Return the shape of a single key-cache block for one layer.

        The head dimension is split into chunks of ``x`` elements, where ``x``
        is how many elements of ``self.dtype`` fit in 16 bytes. The shape is
        ``(num_heads, head_size // x, block_size, x)``.

        Raises:
            ValueError: If ``head_size`` is not a multiple of ``x``. Floor
                division would otherwise silently drop the trailing elements
                of every head from the allocation.
        """
        x = 16 // _get_dtype_size(self.dtype)
        if self.head_size % x != 0:
            raise ValueError(
                f"head_size ({self.head_size}) must be a multiple of {x} "
                f"for dtype {self.dtype} to lay out the key cache.")
        return (
            self.num_heads,
            self.head_size // x,
            self.block_size,
            x,
        )

    def get_value_block_shape(self) -> Tuple[int, int, int]:
        """Return the shape of a single value-cache block for one layer.

        The shape is ``(num_heads, head_size, block_size)``.
        """
        return (
            self.num_heads,
            self.head_size,
            self.block_size,
        )

    def _allocate_cache(
        self,
        num_blocks: int,
        device: str,
        pin_memory: bool,
    ) -> List[KVCache]:
        """Allocate uninitialized key/value block tensors for every layer.

        Args:
            num_blocks: Number of cache blocks per layer.
            device: Device on which the tensors are created.
            pin_memory: Whether to allocate page-locked host memory (only
                meaningful for CPU tensors).

        Returns:
            A list with one ``(key_blocks, value_blocks)`` pair per layer.
        """
        key_shape = (num_blocks, *self.get_key_block_shape())
        value_shape = (num_blocks, *self.get_value_block_shape())
        cache: List[KVCache] = []
        for _ in range(self.num_layers):
            key_blocks = torch.empty(
                size=key_shape,
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            value_blocks = torch.empty(
                size=value_shape,
                dtype=self.dtype,
                device=device,
                pin_memory=pin_memory,
            )
            cache.append((key_blocks, value_blocks))
        return cache

    def allocate_gpu_cache(self) -> List[KVCache]:
        """Allocate the per-layer KV cache on the CUDA device."""
        return self._allocate_cache(
            self.num_gpu_blocks, device="cuda", pin_memory=False)

    def allocate_cpu_cache(self) -> List[KVCache]:
        """Allocate the per-layer KV cache in pinned CPU memory.

        The device is set to ``"cpu"`` explicitly, so pinning does not depend
        on the process-wide default device.
        """
        return self._allocate_cache(
            self.num_cpu_blocks, device="cpu", pin_memory=True)

    def _swap(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        """Copy blocks from ``src`` to ``dst`` for every layer.

        The copies are issued on ``self.cache_stream``. After each layer's key
        and value blocks are enqueued, ``self.events[i]`` is recorded on that
        stream.

        Args:
            src: Per-layer source caches.
            dst: Per-layer destination caches.
            src_to_dst: Block mapping forwarded to ``cache_ops.swap_blocks``
                (presumably source block index -> destination block index).
        """
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                cache_ops.swap_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                cache_ops.swap_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                self.events[i].record(stream=self.cache_stream)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Copy the mapped blocks from the CPU cache into the GPU cache."""
        self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Copy the mapped blocks from the GPU cache into the CPU cache."""
        self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy GPU cache blocks in place via ``cache_ops.copy_blocks``.

        Args:
            src_to_dsts: Mapping forwarded to ``cache_ops.copy_blocks``
                (presumably each source block index to the list of destination
                block indices it is copied to).
        """
        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
        value_caches = [value_cache for _, value_cache in self.gpu_cache]
        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        This counts the key and the value storage for every layer:
        ``num_layers * 2 * block_size * num_heads * head_size * dtype_size``.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = _get_dtype_size(model_config.dtype)
        return dtype_size * total
|
|
|
|
|
|
|
|
|
|
|
|
def _get_dtype_size(dtype: torch.dtype) -> int:
    """Return the number of bytes one element of ``dtype`` occupies."""
    # An empty tensor carries the dtype, which is all element_size() needs.
    probe = torch.tensor([], dtype=dtype)
    num_bytes = probe.element_size()
    return num_bytes
|