Implement block copy kernel to optimize beam search (#32)

Woosuk Kwon 2023-04-07 17:45:07 -07:00 committed by GitHub
parent a490aafa36
commit 0f40557af6
6 changed files with 154 additions and 48 deletions

View File

@@ -50,14 +50,15 @@ def main(args: argparse.Namespace):
         block_size=args.block_size,
     )
     sampling_params_dict = {
-        'n': 1,
-        'temperature': 0.0,
+        'n': args.n,
+        'temperature': 0.0 if args.use_beam_search else 1.0,
         'top_p': 1.0,
-        'use_beam_search': False,
+        'use_beam_search': args.use_beam_search,
         'stop_token_ids': set(),
         'max_num_steps': args.output_len,
     }
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
     print(sampling_params)
     input_token_ids = [0] * args.input_len
     def profile_step(profile=False):
@@ -93,6 +94,8 @@ if __name__ == '__main__':
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
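
The two new flags flow directly into the sampling parameters above. A minimal sketch of the interaction, with illustrative values (the flag names come from the diff; everything else is for demonstration):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--n', type=int, default=1)
    parser.add_argument('--use-beam-search', action='store_true')
    args = parser.parse_args(['--n', '4', '--use-beam-search'])

    # Beam search expands candidates deterministically, so the benchmark
    # pins temperature to 0.0; plain sampling keeps temperature at 1.0.
    temperature = 0.0 if args.use_beam_search else 1.0
    assert (args.n, temperature) == (4, 0.0)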

View File

@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
         vocab_size = logprobs.size(-1)
         beam_width = len(seq_ids)
         _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        topk_ids = topk_ids.tolist()
+        seq_idx = [i // vocab_size for i in topk_ids]
         beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = (topk_ids % vocab_size).tolist()
+        token_ids = [i % vocab_size for i in topk_ids]

         beam_outputs: Dict[int, Tuple[int, int]] = {}
         outstanding_beams: List[Tuple[int, int]] = []
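
The rewrite converts the top-k indices to a Python list once, then recovers the beam index and token id with plain integer arithmetic instead of tensor-side div/mod. A small worked example of the index recovery (shapes and values illustrative):

    import torch

    beam_width, vocab_size = 2, 5
    seq_ids = [7, 9]  # ids of the live sequences in the beam group
    logprobs = torch.randn(beam_width, vocab_size)

    # Top beam_width candidates across all (beam, token) pairs at once.
    _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
    topk_ids = topk_ids.tolist()
    seq_idx = [i // vocab_size for i in topk_ids]   # row: originating beam
    token_ids = [i % vocab_size for i in topk_ids]  # column: token in vocab
    beam_seq_ids = [seq_ids[i] for i in seq_idx]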

View File

@@ -120,24 +120,8 @@ class CacheEngine:
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

-    def _copy(
-        self,
-        src: List[KVCache],
-        dst: List[KVCache],
-        src_to_dsts: Dict[int, List[int]],
-    ) -> None:
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.copy_blocks(
-                    src_key_cache, dst_key_cache, src_to_dsts)
-                # Copy the value blocks.
-                cache_ops.copy_blocks(
-                    src_value_cache, dst_value_cache, src_to_dsts)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
-
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
+        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
+        value_caches = [value_cache for _, value_cache in self.gpu_cache]
+        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
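
The per-layer loop, the dedicated cache stream, and the per-layer events are gone; the engine unzips the per-layer (key, value) pairs and issues one fused kernel call. A pure-PyTorch sketch of what that call computes, assuming gpu_cache is a list of (key_cache, value_cache) tensor pairs whose first dimension indexes blocks:

    from typing import Dict, List, Tuple

    import torch

    def copy_blocks_reference(
        gpu_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dsts: Dict[int, List[int]],
    ) -> None:
        # A block copy is a slice copy within the same cache tensor.
        for key_cache, value_cache in gpu_cache:
            for src, dsts in src_to_dsts.items():
                for dst in dsts:
                    key_cache[dst].copy_(key_cache[src])
                    value_cache[dst].copy_(value_cache[src])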

View File

@@ -9,8 +9,8 @@ void swap_blocks(
   const std::map<int64_t, int64_t>& block_mapping);

 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);

 void reshape_and_cache(

View File

@@ -43,33 +43,93 @@ void swap_blocks(
   }
 }

+namespace cacheflow {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int src_block_number = block_mapping[2 * pair_idx];
+  int dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int src_block_offset = src_block_number * numel_per_block;
+  const int dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace cacheflow
+
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
-  torch::Device src_device = src.device();
-  torch::Device dst_device = dst.device();
-  assert(src_device.is_cuda() && dst_device.is_cuda());
-  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
-
-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
-
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());
+
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      int64_t src_offset = src_block_number * block_size_in_bytes;
-      int64_t dst_offset = dst_block_number * block_size_in_bytes;
-      cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
+    int src_block_number = pair.first;
+    for (int dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
     }
   }
+  int* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int>(),
+        numel_per_block);
+    }));
 }

 namespace cacheflow {
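
The host side mainly marshals arguments: it packs one device pointer per layer into an int64 array and flattens the {src: [dst, ...]} mapping into interleaved (src, dst) pairs, so a single launch covers every layer and every pair. A Python sketch of the same packing, assuming a CUDA device is available (shapes illustrative):

    import torch

    num_layers = 2
    key_caches = [torch.randn(16, 64, device='cuda') for _ in range(num_layers)]

    # One int64 per layer: the device address of that layer's cache.
    key_cache_ptrs = torch.tensor(
        [t.data_ptr() for t in key_caches], dtype=torch.int64, device='cuda')

    # Flatten {src: [dst, ...]} into [src0, dst0, src1, dst1, ...].
    block_mapping = {0: [2], 1: [3, 4]}
    flat = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            flat += [src, dst]
    assert flat == [0, 2, 1, 3, 1, 4]  # num_pairs == 3
    block_mapping_tensor = torch.tensor(flat, dtype=torch.int, device='cuda')

The kernel then runs on a (num_layers, num_pairs) grid: thread block (x, y) copies pair y within layer x, with threads striding over the numel_per_block elements of the key cache and then the value cache.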

View File

@@ -5,6 +5,61 @@ import torch
 from cacheflow import cache_ops


+def test_copy_blocks(
+    num_mappings: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    # Generate random block mappings.
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining_blocks, num_mappings)
+    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.randn(
+            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_caches.append(key_cache)
+    cloned_key_caches = []
+    for key_cache in key_caches:
+        cloned_key_caches.append(key_cache.clone())
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.randn(
+            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_caches.append(value_cache)
+    cloned_value_caches = []
+    for value_cache in value_caches:
+        cloned_value_caches.append(value_cache.clone())
+
+    # Call the copy blocks kernel.
+    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+    # Reference implementation: apply the same copies to the clones.
+    for src, dsts in block_mapping.items():
+        for dst in dsts:
+            for cloned_key_cache in cloned_key_caches:
+                cloned_key_cache[dst] = cloned_key_cache[src]
+            for cloned_value_cache in cloned_value_caches:
+                cloned_value_cache[dst] = cloned_value_cache[src]
+
+    # Compare the results.
+    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        assert torch.allclose(key_cache, cloned_key_cache)
+    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+        assert torch.allclose(value_cache, cloned_value_cache)
+
+
 def test_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
@@ -46,6 +101,9 @@ def test_reshape_and_cache(

 @torch.inference_mode()
 def test_cache() -> None:
+    test_copy_blocks(
+        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+        block_size=8, num_blocks=1024, dtype=torch.half)
     test_reshape_and_cache(
         num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
         dtype=torch.half)
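
The new test exercises a single shape/dtype combination via test_cache(). If broader coverage is wanted, a pytest-style parametrization over the same helper is one option (a sketch, not part of this commit):

    import pytest
    import torch

    @pytest.mark.parametrize('dtype', [torch.half, torch.float])
    @pytest.mark.parametrize('num_layers', [1, 7])
    @torch.inference_mode()
    def test_copy_blocks_sweep(dtype: torch.dtype, num_layers: int) -> None:
        test_copy_blocks(
            num_mappings=23, num_layers=num_layers, num_heads=17, head_size=16,
            block_size=8, num_blocks=1024, dtype=dtype)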