Implement cache ops
This commit is contained in:
parent
a1c67e6db8
commit
6f058c7ba8
@ -1,6 +1,7 @@
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from cacheflow import ops
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -92,14 +93,30 @@ class CacheEngine:
|
|||||||
cpu_cache.append((key_blocks, value_blocks))
|
cpu_cache.append((key_blocks, value_blocks))
|
||||||
return cpu_cache
|
return cpu_cache
|
||||||
|
|
||||||
|
def _copy_blocks(
|
||||||
|
self,
|
||||||
|
src: List[KVCache],
|
||||||
|
dst: List[KVCache],
|
||||||
|
src_to_dst: Dict[int, int],
|
||||||
|
) -> None:
|
||||||
|
with torch.cuda.stream(self.cache_stream):
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
src_key_cache, src_value_cache = src[i]
|
||||||
|
dst_key_cache, dst_value_cache = dst[i]
|
||||||
|
# Copy the key blocks.
|
||||||
|
ops.copy_cache_blocks(
|
||||||
|
src_key_cache, dst_key_cache, src_to_dst)
|
||||||
|
# Copy the value blocks.
|
||||||
|
ops.copy_cache_blocks(
|
||||||
|
src_value_cache, dst_value_cache, src_to_dst)
|
||||||
|
event = self.events[i]
|
||||||
|
event.record(stream=self.cache_stream)
|
||||||
|
|
||||||
def copy(self, src_to_dst: Dict[int, int]) -> None:
|
def copy(self, src_to_dst: Dict[int, int]) -> None:
|
||||||
for event in self.events:
|
self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
|
||||||
pass
|
|
||||||
|
|
||||||
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
|
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
|
||||||
for event in self.events:
|
self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
|
||||||
pass
|
|
||||||
|
|
||||||
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
|
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
|
||||||
for event in self.events:
|
self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
|
||||||
pass
|
|
||||||
|
20
csrc/cache.cpp
Normal file
20
csrc/cache.cpp
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void copy_blocks(
|
||||||
|
torch::Tensor& src,
|
||||||
|
torch::Tensor& dst,
|
||||||
|
const std::map<int64_t, int64_t>& block_mapping);
|
||||||
|
|
||||||
|
void copy_cache_blocks(
|
||||||
|
torch::Tensor& src,
|
||||||
|
torch::Tensor& dst,
|
||||||
|
const std::map<int64_t, int64_t>& block_mapping) {
|
||||||
|
copy_blocks(src, dst, block_mapping);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"copy_cache_blocks",
|
||||||
|
©_cache_blocks,
|
||||||
|
"Copy the cache blocks from src to dst");
|
||||||
|
}
|
43
csrc/cache_kernel.cu
Normal file
43
csrc/cache_kernel.cu
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
void copy_blocks(
|
||||||
|
torch::Tensor& src,
|
||||||
|
torch::Tensor& dst,
|
||||||
|
const std::map<int64_t, int64_t>& block_mapping) {
|
||||||
|
torch::Device src_device = src.device();
|
||||||
|
torch::Device dst_device = dst.device();
|
||||||
|
cudaMemcpyKind memcpy_type;
|
||||||
|
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||||
|
assert(src_device.index() == dst_device.index());
|
||||||
|
memcpy_type = cudaMemcpyDeviceToDevice;
|
||||||
|
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||||
|
memcpy_type = cudaMemcpyDeviceToHost;
|
||||||
|
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
||||||
|
memcpy_type = cudaMemcpyHostToDevice;
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void *src_ptr = src.data_ptr();
|
||||||
|
void *dst_ptr = dst.data_ptr();
|
||||||
|
|
||||||
|
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
for (const auto& pair : block_mapping) {
|
||||||
|
int64_t src_block_number = pair.first;
|
||||||
|
int64_t dst_block_number = pair.second;
|
||||||
|
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||||
|
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||||
|
cudaMemcpyAsync(
|
||||||
|
dst_ptr + dst_offset,
|
||||||
|
src_ptr + src_offset,
|
||||||
|
block_size_in_bytes,
|
||||||
|
memcpy_type,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
}
|
23
setup.py
23
setup.py
@ -0,0 +1,23 @@
|
|||||||
|
import setuptools
|
||||||
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
|
CXX_FLAGS = ['-g']
|
||||||
|
NVCC_FLAGS = ['-O2']
|
||||||
|
|
||||||
|
|
||||||
|
ext_modules = []
|
||||||
|
|
||||||
|
# Cache operations.
|
||||||
|
cache_extension = cpp_extension.CUDAExtension(
|
||||||
|
name='cacheflow.ops',
|
||||||
|
sources=['csrc/cache.cpp', 'csrc/cache_kernel.cu'],
|
||||||
|
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||||
|
)
|
||||||
|
ext_modules.append(cache_extension)
|
||||||
|
|
||||||
|
setuptools.setup(
|
||||||
|
name='cacheflow',
|
||||||
|
requires_python='>=3.9',
|
||||||
|
ext_modules=ext_modules,
|
||||||
|
cmdclass={'build_ext': cpp_extension.BuildExtension},
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user