Implement cache ops

Woosuk Kwon 2023-02-16 07:47:03 +00:00
parent a1c67e6db8
commit 6f058c7ba8
4 changed files with 109 additions and 6 deletions


@@ -1,6 +1,7 @@
from typing import Dict, List, Tuple

import torch

from cacheflow import ops

KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -92,14 +93,30 @@ class CacheEngine:
        cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache

    def _copy_blocks(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                ops.copy_cache_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                ops.copy_cache_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                # Record an event per layer so that callers can later
                # wait for this layer's copies to finish.
                event = self.events[i]
                event.record(stream=self.cache_stream)

    def copy(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
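For context, a minimal usage sketch of the new methods (hypothetical: `cache_engine` stands for an already-constructed CacheEngine, and the block numbers are made up). Each mapping sends a source block number to a destination block number:

# A minimal usage sketch; `cache_engine` is assumed to be an
# already-constructed CacheEngine, and the block numbers are made up.
cache_engine.swap_out({0: 5, 1: 6})  # GPU blocks 0, 1 -> CPU blocks 5, 6
cache_engine.swap_in({5: 0, 6: 1})   # CPU blocks 5, 6 -> GPU blocks 0, 1
cache_engine.copy({2: 7})            # GPU block 2 -> GPU block 7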

csrc/cache.cpp Normal file (20 lines)

@@ -0,0 +1,20 @@
#include <torch/extension.h>

#include <map>

// Defined in csrc/cache_kernel.cu.
void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_cache_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  copy_blocks(src, dst, block_mapping);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "copy_cache_blocks",
    &copy_cache_blocks,
    "Copy the cache blocks from src to dst");
}
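Once the extension is built, the binding is callable from Python; pybind11 converts a Python dict with integer keys into the `std::map<int64_t, int64_t>` parameter. A small sketch (the tensor shape is illustrative, not prescribed by this diff):

import torch
from cacheflow import ops

src = torch.randn(4, 16, device='cuda')  # 4 blocks of 16 elements each
dst = torch.zeros_like(src)
ops.copy_cache_blocks(src, dst, {0: 2, 1: 3})  # block 0 -> 2, block 1 -> 3
torch.cuda.synchronize()  # the underlying copies are asynchronous
assert torch.equal(dst[2], src[0]) and torch.equal(dst[3], src[1])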

csrc/cache_kernel.cu Normal file (43 lines)

@@ -0,0 +1,43 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cassert>
#include <map>

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    assert(src_device.index() == dst_device.index());
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    assert(false);
  }

  // Use char* so that the byte offsets below are standard C++;
  // pointer arithmetic on void* is a compiler extension.
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  // Each block is one contiguous slice of src/dst along dim 0.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
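Each mapped block is one contiguous slice along dim 0, so the loop above has the same effect as the following per-block tensor copies (a Python sketch of the semantics only; the kernel instead issues raw cudaMemcpyAsync calls on the current stream):

# Semantics of copy_blocks, sketched in Python. Each block spans
# src.element_size() * src[0].numel() bytes.
for src_block, dst_block in block_mapping.items():
    dst[dst_block].copy_(src[src_block], non_blocking=True)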

setup.py Normal file (23 lines)

@@ -0,0 +1,23 @@
import setuptools
from torch.utils import cpp_extension

CXX_FLAGS = ['-g']
NVCC_FLAGS = ['-O2']

ext_modules = []

# Cache operations.
cache_extension = cpp_extension.CUDAExtension(
    name='cacheflow.ops',
    sources=['csrc/cache.cpp', 'csrc/cache_kernel.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(cache_extension)

setuptools.setup(
    name='cacheflow',
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
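A quick way to check the build (assuming a local CUDA toolchain; the commands are standard pip/setuptools, nothing specific to this repo):

# pip install -e .        (or: python setup.py build_ext --inplace)
# Then verify that the extension module loads:
from cacheflow import ops
print(ops.copy_cache_blocks.__doc__)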