Implement block copy kernel to optimize beam search (#32)
commit 0f40557af6
parent a490aafa36
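Beam search makes several sequences share KV-cache blocks; when beams diverge, a shared block has to be duplicated, and this commit replaces the per-block memcpy path with one fused kernel launch that performs all of those copies at once. The mapping handed down from Python is a Dict[int, List[int]] from a source block number to its destination block numbers. A rough, hypothetical illustration (the block numbers and beam block tables below are invented for this sketch, not taken from the commit):

# Hypothetical copy-on-write scenario for two diverging beams.
block_table_beam0 = [3, 7]   # physical blocks backing beam 0's KV cache
block_table_beam1 = [3, 7]   # beam 1 currently shares both blocks

# Beam 1 appends a different token, so its last block needs its own copy.
src_to_dsts = {7: [12]}      # Dict[int, List[int]]: src block -> dst blocks
block_table_beam1[-1] = 12

# A single fused call then performs every block copy on the GPU:
# cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)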
@@ -50,14 +50,15 @@ def main(args: argparse.Namespace):
         block_size=args.block_size,
     )
     sampling_params_dict = {
-        'n': 1,
-        'temperature': 0.0,
+        'n': args.n,
+        'temperature': 0.0 if args.use_beam_search else 1.0,
         'top_p': 1.0,
-        'use_beam_search': False,
+        'use_beam_search': args.use_beam_search,
         'stop_token_ids': set(),
         'max_num_steps': args.output_len,
     }
     sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    print(sampling_params)
     input_token_ids = [0] * args.input_len

     def profile_step(profile=False):
@@ -93,6 +94,8 @@ if __name__ == '__main__':
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
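For a concrete picture of the new flags, this is the dict the benchmark would build for a hypothetical `--n 4 --use-beam-search` run (every other value comes from the defaults shown above; the flag values are assumptions for this example):

# Sketch of the sampling parameters for `--n 4 --use-beam-search`.
sampling_params_dict = {
    'n': 4,                      # args.n
    'temperature': 0.0,          # 0.0 when beam search is enabled
    'top_p': 1.0,
    'use_beam_search': True,     # args.use_beam_search
    'stop_token_ids': set(),
    'max_num_steps': 128,        # args.output_len default
}
# sampling_params = SamplingParams.from_dict(sampling_params_dict)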
@@ -185,9 +185,10 @@ def _sample_from_generation_tokens(
         vocab_size = logprobs.size(-1)
         beam_width = len(seq_ids)
         _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
-        seq_idx = torch.div(topk_ids, vocab_size, rounding_mode='floor').tolist()
+        topk_ids = topk_ids.tolist()
+        seq_idx = [i // vocab_size for i in topk_ids]
         beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = (topk_ids % vocab_size).tolist()
+        token_ids = [i % vocab_size for i in topk_ids]

         beam_outputs: Dict[int, Tuple[int, int]] = {}
         outstanding_beams: List[Tuple[int, int]] = []
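A worked example of the index arithmetic above: with beam_width = 2 and a toy vocab_size = 4, each flattened top-k index decomposes into a parent-beam index (integer division) and a token id (modulo):

import torch

# Toy log-probabilities: 2 beams x 4 vocab entries.
logprobs = torch.tensor([[-0.1, -2.0, -3.0, -4.0],
                         [-1.5, -0.2, -2.5, -3.5]])
vocab_size = logprobs.size(-1)      # 4
beam_width = logprobs.size(0)       # 2

_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
topk_ids = topk_ids.tolist()                    # [0, 5]
seq_idx = [i // vocab_size for i in topk_ids]   # [0, 1] -> parent beams
token_ids = [i % vocab_size for i in topk_ids]  # [0, 1] -> chosen tokens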
@@ -120,24 +120,8 @@ class CacheEngine:
     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

-    def _copy(
-        self,
-        src: List[KVCache],
-        dst: List[KVCache],
-        src_to_dsts: Dict[int, List[int]],
-    ) -> None:
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.copy_blocks(
-                    src_key_cache, dst_key_cache, src_to_dsts)
-                # Copy the value blocks.
-                cache_ops.copy_blocks(
-                    src_value_cache, dst_value_cache, src_to_dsts)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
-
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
-        self._copy(self.gpu_cache, self.gpu_cache, src_to_dsts)
+        key_caches = [key_cache for key_cache, _ in self.gpu_cache]
+        value_caches = [value_cache for _, value_cache in self.gpu_cache]
+        # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
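Since gpu_cache holds one (key_cache, value_cache) tuple per layer, the new copy() simply splits it into two parallel lists and makes one fused call, instead of the old per-layer loop that issued two copy calls per layer on a separate CUDA stream. A minimal sketch of that unzipping, using placeholder tensors rather than a real cache:

import torch

# Stand-in for self.gpu_cache: one (key_cache, value_cache) pair per layer.
num_layers = 2
gpu_cache = [(torch.zeros(4), torch.ones(4)) for _ in range(num_layers)]

key_caches = [key_cache for key_cache, _ in gpu_cache]
value_caches = [value_cache for _, value_cache in gpu_cache]

src_to_dsts = {0: [1]}  # toy mapping: copy block 0 into block 1
# With real block-shaped cache tensors, a single call now covers all layers:
# cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)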
@@ -9,8 +9,8 @@ void swap_blocks(
   const std::map<int64_t, int64_t>& block_mapping);

 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);

 void reshape_and_cache(
@@ -43,33 +43,93 @@ void swap_blocks(
   }
 }

+namespace cacheflow {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int src_block_number = block_mapping[2 * pair_idx];
+  int dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int src_block_offset = src_block_number * numel_per_block;
+  const int dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int src_offset = src_block_offset + i;
+    int dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace cacheflow
+
 void copy_blocks(
-  torch::Tensor& src,
-  torch::Tensor& dst,
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
   const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
-  torch::Device src_device = src.device();
-  torch::Device dst_device = dst.device();
-  assert(src_device.is_cuda() && dst_device.is_cuda());
-  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());

-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
-
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int> block_mapping_vec;
   for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      int64_t src_offset = src_block_number * block_size_in_bytes;
-      int64_t dst_offset = dst_block_number * block_size_in_bytes;
-      cudaMemcpyAsync(
-        dst_ptr + dst_offset,
-        src_ptr + src_offset,
-        block_size_in_bytes,
-        memcpy_type,
-        stream);
+    int src_block_number = pair.first;
+    for (int dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
     }
   }
+  int* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int>(),
+        numel_per_block);
+    }));
 }

 namespace cacheflow {
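The kernel launches one thread block per (layer, src-to-dst pair): blockIdx.x selects the layer, blockIdx.y selects the pair, and the threads stride over the numel_per_block contiguous elements of that block in both the key and the value cache. A PyTorch sketch of the same semantics (essentially the reference loop used by the new test; copy_blocks_reference is an invented name, the real op is cache_ops.copy_blocks):

from typing import Dict, List
import torch

def copy_blocks_reference(
    key_caches: List[torch.Tensor],    # one tensor per layer, dim 0 = block
    value_caches: List[torch.Tensor],
    block_mapping: Dict[int, List[int]],
) -> None:
    # Mirrors the CUDA grid: outer loop = layers (blockIdx.x),
    # inner loop = src->dst pairs (blockIdx.y).
    for key_cache, value_cache in zip(key_caches, value_caches):
        for src, dsts in block_mapping.items():
            for dst in dsts:
                key_cache[dst] = key_cache[src]
                value_cache[dst] = value_cache[src]

The real op does this in a single launch, shipping only the per-layer pointer table and the flattened (src, dst) pairs to the GPU.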
@@ -5,6 +5,61 @@ import torch
 from cacheflow import cache_ops


+def test_copy_blocks(
+    num_mappings: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+) -> None:
+    # Generate random block mappings.
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remainig_blocks, num_mappings)
+    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+
+    # Create the KV cache.
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.randn(
+            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_caches.append(key_cache)
+    cloned_key_caches = []
+    for key_cache in key_caches:
+        cloned_key_caches.append(key_cache.clone())
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.randn(
+            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_caches.append(value_cache)
+    cloned_value_caches = []
+    for value_cache in value_caches:
+        cloned_value_caches.append(value_cache.clone())
+
+    # Call the copy blocks kernel.
+    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+    # Reference implementation.
+    for src, dsts in block_mapping.items():
+        for dst in dsts:
+            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+                cloned_key_cache[dst] = cloned_key_cache[src]
+            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+                cloned_value_cache[dst] = cloned_value_cache[src]
+
+    # Compare the results.
+    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+        assert torch.allclose(key_cache, cloned_key_cache)
+    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+        assert torch.allclose(value_cache, cloned_value_cache)
+
+
 def test_reshape_and_cache(
     num_tokens: int,
     num_heads: int,
@@ -46,6 +101,9 @@ def test_reshape_and_cache(

 @torch.inference_mode()
 def test_cache() -> None:
+    test_copy_blocks(
+        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+        block_size=8, num_blocks=1024, dtype=torch.half)
     test_reshape_and_cache(
         num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
         dtype=torch.half)
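One detail of the test worth spelling out: the key-cache layout packs the head dimension in groups of x elements, where x is chosen so each group spans 16 bytes. For torch.half, element_size() is 2, so x = 16 // 2 = 8 and a head_size of 16 splits into head_size // x = 2 groups. Restating the shape arithmetic with the parameters used in test_cache above:

import torch

dtype = torch.half
x = 16 // torch.tensor([], dtype=dtype).element_size()   # 16 // 2 = 8

num_blocks, num_heads, head_size, block_size = 1024, 17, 16, 8
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
print(key_cache_shape)    # (1024, 17, 2, 8, 8)
print(value_cache_shape)  # (1024, 17, 16, 8)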