From 6f058c7ba88e657457ad5db9226d5b194e5aaabe Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 16 Feb 2023 07:47:03 +0000
Subject: [PATCH] Implement cache ops

---
 cacheflow/worker/cache_engine.py | 29 ++++++++++++++++-----
 csrc/cache.cpp                   | 20 +++++++++++++++
 csrc/cache_kernel.cu             | 43 ++++++++++++++++++++++++++++++++
 setup.py                         | 23 +++++++++++++++++
 4 files changed, 109 insertions(+), 6 deletions(-)
 create mode 100644 csrc/cache.cpp
 create mode 100644 csrc/cache_kernel.cu

diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py
index 03a60f9b..7f4d291b 100644
--- a/cacheflow/worker/cache_engine.py
+++ b/cacheflow/worker/cache_engine.py
@@ -1,6 +1,7 @@
 from typing import Dict, List, Tuple
 
 import torch
+from cacheflow import ops
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -92,14 +93,30 @@ class CacheEngine:
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
 
+    def _copy_blocks(
+        self,
+        src: List[KVCache],
+        dst: List[KVCache],
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        with torch.cuda.stream(self.cache_stream):
+            for i in range(self.num_layers):
+                src_key_cache, src_value_cache = src[i]
+                dst_key_cache, dst_value_cache = dst[i]
+                # Copy the key blocks.
+                ops.copy_cache_blocks(
+                    src_key_cache, dst_key_cache, src_to_dst)
+                # Copy the value blocks.
+                ops.copy_cache_blocks(
+                    src_value_cache, dst_value_cache, src_to_dst)
+                event = self.events[i]
+                event.record(stream=self.cache_stream)
+
     def copy(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)
 
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
diff --git a/csrc/cache.cpp b/csrc/cache.cpp
new file mode 100644
index 00000000..786452c0
--- /dev/null
+++ b/csrc/cache.cpp
@@ -0,0 +1,20 @@
+#include <torch/extension.h>
+
+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping);
+
+void copy_cache_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping) {
+  copy_blocks(src, dst, block_mapping);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "copy_cache_blocks",
+    &copy_cache_blocks,
+    "Copy the cache blocks from src to dst");
+}
diff --git a/csrc/cache_kernel.cu b/csrc/cache_kernel.cu
new file mode 100644
index 00000000..7a8befc0
--- /dev/null
+++ b/csrc/cache_kernel.cu
@@ -0,0 +1,43 @@
+#include <torch/extension.h>
+
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cassert>
+#include <map>
+
+void copy_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  cudaMemcpyKind memcpy_type;
+  if (src_device.is_cuda() && dst_device.is_cuda()) {
+    assert(src_device.index() == dst_device.index());
+    memcpy_type = cudaMemcpyDeviceToDevice;
+  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
+    memcpy_type = cudaMemcpyDeviceToHost;
+  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
+    memcpy_type = cudaMemcpyHostToDevice;
+  } else {
+    assert(false);
+  }
+
+  char* src_ptr = static_cast<char*>(src.data_ptr());
+  char* dst_ptr = static_cast<char*>(dst.data_ptr());
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    int64_t dst_block_number = pair.second;
+    int64_t src_offset = src_block_number * block_size_in_bytes;
+    int64_t dst_offset = dst_block_number * block_size_in_bytes;
+    cudaMemcpyAsync(
+      dst_ptr + dst_offset,
+      src_ptr + src_offset,
+      block_size_in_bytes,
+      memcpy_type,
+      stream);
+  }
+}
diff --git a/setup.py b/setup.py
index e69de29b..1262323b 100644
--- a/setup.py
+++ b/setup.py
@@ -0,0 +1,23 @@
+import setuptools
+from torch.utils import cpp_extension
+
+CXX_FLAGS = ['-g']
+NVCC_FLAGS = ['-O2']
+
+
+ext_modules = []
+
+# Cache operations.
+cache_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.ops',
+    sources=['csrc/cache.cpp', 'csrc/cache_kernel.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(cache_extension)
+
+setuptools.setup(
+    name='cacheflow',
+    python_requires='>=3.9',
+    ext_modules=ext_modules,
+    cmdclass={'build_ext': cpp_extension.BuildExtension},
+)
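
Reviewer note (not part of the patch): once the extension is built, for
example with `pip install -e .` or `python setup.py build_ext --inplace`,
the new op can be smoke-tested directly. The sketch below assumes a CUDA
device; num_blocks and block_numel are illustrative values, not taken from
the patch. The Python dict argument binds to the std::map<int64_t, int64_t>
parameter through pybind11's STL casters, which torch/extension.h pulls in.

    import torch
    from cacheflow import ops

    num_blocks, block_numel = 4, 4096  # hypothetical cache block geometry
    src = torch.arange(num_blocks * block_numel, dtype=torch.float32,
                       device='cuda').view(num_blocks, block_numel)
    dst = torch.zeros_like(src)

    # Copy block 0 -> 2 and block 1 -> 3. Each block spans
    # element_size * src[0].numel() bytes, matching the kernel's offsets.
    ops.copy_cache_blocks(src, dst, {0: 2, 1: 3})

    # The copies are issued with cudaMemcpyAsync on the current stream,
    # so synchronize before inspecting dst.
    torch.cuda.synchronize()
    assert torch.equal(dst[2], src[0])
    assert torch.equal(dst[3], src[1])

CacheEngine.copy/swap_in/swap_out drive the same op across layers on the
dedicated cache_stream and record a per-layer event, presumably so callers
can overlap swaps with compute and wait on self.events[i] before touching
layer i.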