Implement block copy kernel to optimize beam search (#32)

parent a490aafa36
commit 0f40557af6

@@ -50,14 +50,15 @@ def main(args: argparse.Namespace):
         block_size=args.block_size,
     )
     sampling_params_dict = {
-        'n': 1,
-        'temperature': 0.0,
+        'n': args.n,
+        'temperature': 0.0 if args.use_beam_search else 1.0,
         'top_p': 1.0,
-        'use_beam_search': False,
+        'use_beam_search': args.use_beam_search,
         'stop_token_ids': set(),
         'max_num_steps': args.output_len,
     }
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    print(sampling_params)
     input_token_ids = [0] * args.input_len

     def profile_step(profile=False):

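For illustration, a minimal sketch of the dictionary this change produces when beam search is requested; the concrete values below (n=4, output length 128) are hypothetical, not taken from the diff:

    sampling_params_dict = {
        'n': 4,                    # args.n
        'temperature': 0.0,        # 0.0 when beam search is enabled
        'top_p': 1.0,
        'use_beam_search': True,   # args.use_beam_search
        'stop_token_ids': set(),
        'max_num_steps': 128,      # args.output_len
    }
    sampling_params = SamplingParams.from_dict(sampling_params_dict)
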
@@ -93,6 +94,8 @@ if __name__ == '__main__':
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)

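Assuming the benchmark script is run directly (its file name is not shown in this diff), beam search can then be exercised by combining the new flags, e.g. --n 4 --use-beam-search, with the existing --batch-size, --input-len, and --output-len options.
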
@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
     vocab_size = logprobs.size(-1)
     beam_width = len(seq_ids)
     _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-    seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+    topk_ids = topk_ids.tolist()
+    seq_idx = [i // vocab_size for i in topk_ids]
     beam_seq_ids = [seq_ids[i] for i in seq_idx]
-    token_ids = (topk_ids % vocab_size).tolist()
+    token_ids = [i % vocab_size for i in topk_ids]

     beam_outputs: Dict[int, Tuple[int, int]] = {}
     outstanding_beams: List[Tuple[int, int]] = []

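As a clarifying aside, a toy example (hypothetical sizes, not from the diff) of how the flattened top-k indices above decompose into a parent-beam index and a vocabulary token id:

    import torch

    vocab_size = 5
    beam_width = 3
    logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)
    _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
    topk_ids = topk_ids.tolist()
    seq_idx = [i // vocab_size for i in topk_ids]    # index of the parent beam
    token_ids = [i % vocab_size for i in topk_ids]   # token chosen from the vocabulary
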
@@ -120,24 +120,8 @@ class CacheEngine:
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

-    def _copy(
-        self,
-        src: List[KVCache],
-        dst: List[KVCache],
-        src_to_dsts: Dict[int, List[int]],
-    ) -> None:
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.copy_blocks(
-                    src_key_cache, dst_key_cache, src_to_dsts)
-                # Copy the value blocks.
-                cache_ops.copy_blocks(
-                    src_value_cache, dst_value_cache, src_to_dsts)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
-
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
+        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
+        value_caches = [value_cache for _, value_cache in self.gpu_cache]
+        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)

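A hedged usage sketch of the new copy path: when beam search forks a sequence, one physical block may need to be duplicated into several destination blocks, and a single call now covers the key and value caches of every layer. The cache_engine name and the block numbers below are illustrative only:

    src_to_dsts = {3: [17, 42]}      # copy block 3 into blocks 17 and 42
    cache_engine.copy(src_to_dsts)   # dispatches one cache_ops.copy_blocks call
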
@@ -9,8 +9,8 @@ void swap_blocks(
   const std::map<int64_t, int64_t>& block_mapping);

 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);

 void reshape_and_cache(

@@ -43,33 +43,93 @@ void swap_blocks(
   }
 }

+namespace cacheflow {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int src_block_number = block_mapping[2 * pair_idx];
+  int dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int src_block_offset = src_block_number * numel_per_block;
+  const int dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace cacheflow
+
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
-  torch::Device src_device = src.device();
-  torch::Device dst_device = dst.device();
-  assert(src_device.is_cuda() && dst_device.is_cuda());
-  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());

-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
-
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      int64_t src_offset = src_block_number * block_size_in_bytes;
-      int64_t dst_offset = dst_block_number * block_size_in_bytes;
-      cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
+    int src_block_number = pair.first;
+    for (int dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
     }
   }
+  int* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int>(),
+        numel_per_block);
+    }));
 }

 namespace cacheflow {

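For reference, a short PyTorch sketch of the semantics the kernel implements: for every layer and every (src, dst) pair, block dst becomes a copy of block src in both the key cache and the value cache. This mirrors the reference loop in the test below and is not part of the diff:

    def copy_blocks_reference(key_caches, value_caches, block_mapping):
        # Each cache tensor is indexed by physical block number along dimension 0.
        for key_cache, value_cache in zip(key_caches, value_caches):
            for src, dsts in block_mapping.items():
                for dst in dsts:
                    key_cache[dst].copy_(key_cache[src])
                    value_cache[dst].copy_(value_cache[src])
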
@@ -5,6 +5,61 @@ import torch
 from cacheflow import cache_ops


+def test_copy_blocks(
+    num_mappings: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    # Generate random block mappings.
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remainig_blocks, num_mappings)
+    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.randn(
+            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_caches.append(key_cache)
+    cloned_key_caches = []
+    for key_cache in key_caches:
+        cloned_key_caches.append(key_cache.clone())
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.randn(
+            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_caches.append(value_cache)
+    cloned_value_caches = []
+    for value_cache in value_caches:
+        cloned_value_caches.append(value_cache.clone())
+
+    # Call the copy blocks kernel.
+    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+    # Reference implementation.
+    for src, dsts in block_mapping.items():
+        for dst in dsts:
+            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+                cloned_key_cache[dst] = cloned_key_cache[src]
+            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+                cloned_value_cache[dst] = cloned_value_cache[src]
+
+    # Compare the results.
+    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        assert torch.allclose(key_cache, cloned_key_cache)
+    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+        assert torch.allclose(value_cache, cloned_value_cache)
+
+
 def test_reshape_and_cache(
     num_tokens: int,
     num_heads: int,

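Note that the test maps each source block to exactly one destination (block_mapping = {src: [dst]}), while the kernel and the CacheEngine.copy path accept one-to-many mappings (std::map<int64_t, std::vector<int64_t>> on the C++ side, Dict[int, List[int]] on the Python side), which is the case beam search produces when several beams fork from the same block.
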
@@ -46,6 +101,9 @@ def test_reshape_and_cache(

 @torch.inference_mode()
 def test_cache() -> None:
+    test_copy_blocks(
+        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+        block_size=8, num_blocks=1024, dtype=torch.half)
     test_reshape_and_cache(
         num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
         dtype=torch.half)