Memcpy kernel for flash attention (#29)
* optimize
* add benchmark
* add assert
* add test
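In effect, the new cache_ops.gather_cached_kv op copies each token's key and value out of the paged KV cache into contiguous tensors that flash attention can consume directly. Logically it performs the following per-token gather (a Python sketch of the semantics, using the names defined in the diff below, not the actual implementation):

# For each token, look up its slot in the paged cache and copy key/value out.
for i in range(num_tokens):
    block_idx = slot_mapping[i] // block_size
    block_offset = slot_mapping[i] % block_size
    # key_cache: [num_blocks, num_heads, head_size // x, block_size, x]
    key[i] = key_cache[block_idx, :, :, block_offset, :].reshape(num_heads, head_size)
    # value_cache: [num_blocks, num_heads, head_size, block_size]
    value[i] = value_cache[block_idx, :, :, block_offset]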
This commit is contained in:
parent b9926f7f66
commit e3cec88aa5

benchmark/benchmark_cache.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import functools
import random
import time

import torch

from cacheflow import cache_ops


def benchmark(name, f, size: int, num_warmup: int = 10, num_iters: int = 100):
    for _ in range(num_warmup):
        f()
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_iters):
        f()
    torch.cuda.synchronize()
    end = time.time()
    avg_time = (end - start) / num_iters
    print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
    print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')


@torch.inference_mode()
def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
          f'head_size: {head_size}, block_size: {block_size}, '
          f'num_blocks: {num_blocks}, dtype: {dtype}')

    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    # Gather the cached key/value for the sampled slots into the contiguous
    # key/value tensors (the memcpy that feeds flash attention).
    def run():
        cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

    benchmark('gather_cached_kv', run,
              size=num_tokens * num_heads * head_size * 2 * qkv.element_size())


if __name__ == '__main__':
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    run_benchmark = functools.partial(
        test_gather_cached_kv,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )

    for i in range(6, 12):
        run_benchmark(num_tokens=2 ** i)

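For a sense of scale, the size passed to benchmark() counts the gathered bytes (key plus value for every token). A quick check for the largest configuration in the sweep (a sketch, fp16 as above):

import torch

# Largest config in the sweep: 2**11 tokens, 40 heads, head size 128, fp16.
num_tokens, num_heads, head_size = 2 ** 11, 40, 128
element_size = torch.tensor([], dtype=torch.half).element_size()  # 2 bytes

# Key + value (the factor of 2), one element per head dimension per token.
size = num_tokens * num_heads * head_size * 2 * element_size
print(f'{size / 2 ** 20:.1f} MiB gathered per call')  # 40.0 MiB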
@@ -20,6 +20,13 @@ void reshape_and_cache(
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

void gather_cached_kv(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "swap_blocks",
@@ -33,4 +40,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "reshape_and_cache",
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  m.def(
    "gather_cached_kv",
    &gather_cached_kv,
    "Gather key and value from the cache into contiguous QKV tensors");
}

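From Python, the binding writes the gathered key/value into pre-allocated output tensors. A minimal sketch of the expected shapes and dtypes (mirroring the benchmark above; slot_mapping must be int32, matching the int* kernel argument):

import torch
from cacheflow import cache_ops

num_tokens, num_heads, head_size = 4, 2, 16
num_blocks, block_size = 2, 8
x = 16 // torch.tensor([], dtype=torch.half).element_size()  # 8 for fp16

# Outputs: contiguous per-token key/value, written in place by the op.
key = torch.empty(num_tokens, num_heads, head_size, dtype=torch.half, device='cuda')
value = torch.empty_like(key)

# Inputs: the paged caches and the int32 slot index for each token.
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.half, device='cuda')
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.half, device='cuda')
slot_mapping = torch.randint(num_blocks * block_size, (num_tokens,),
                             dtype=torch.int, device='cuda')

cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)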
@@ -176,6 +176,124 @@ __global__ void reshape_and_cache_kernel(
  }
}

// Grid: (num_tokens). Each thread block gathers one token's key and value.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,      // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // Number of key (and value) elements to copy per token.
  const int dim = num_heads * head_size;
  for (int i = threadIdx.x; i < dim; i += blockDim.x) {
    const int tgt_key_idx = token_idx * key_stride + i;
    const int tgt_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
    const int x_offset = head_offset % x;

    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    const int src_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;

    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
  }
}

template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
    scalar_t *__restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
    scalar_t *__restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
    const scalar_t *__restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
    const scalar_t *__restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
    const int *__restrict__ slot_mapping,      // [num_tokens]
    const int key_stride,
    const int value_stride,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x)
{
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  const int dim = num_heads * head_size;
  assert(dim % 4 == 0);  // this is true for known use cases
  const int unroll_factor = 4;
  const int unrolled_dim = dim / unroll_factor;

  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
  {
    int tgt_key_indices[unroll_factor];
    int tgt_value_indices[unroll_factor];
    int src_key_indices[unroll_factor];
    int src_value_indices[unroll_factor];
    scalar_t keys_to_store[unroll_factor];
    scalar_t values_to_store[unroll_factor];

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      int index = i + j * unrolled_dim;

      const int tgt_key_idx = token_idx * key_stride + index;
      const int tgt_value_idx = token_idx * value_stride + index;

      const int head_idx = index / head_size;
      const int head_offset = index % head_size;
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;

      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
      const int src_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;

      tgt_key_indices[j] = tgt_key_idx;
      tgt_value_indices[j] = tgt_value_idx;
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
      values_to_store[j] = __ldg(&value_cache[src_value_idx]);
    }

    #pragma unroll
    for (int j = 0; j < unroll_factor; ++j)
    {
      key[tgt_key_indices[j]] = keys_to_store[j];
      value[tgt_value_indices[j]] = values_to_store[j];
    }
  }
}

} // namespace cacheflow

void reshape_and_cache(
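Both gather kernels compute src_key_idx and src_value_idx by flattening the cache indices by hand; the arithmetic is ordinary row-major indexing into the [num_blocks, num_heads, head_size/x, block_size, x] key-cache layout. A quick host-side check of the key formula (a sketch with small made-up sizes; the value formula follows the same pattern):

import torch

num_blocks, num_heads, head_size, block_size, x = 2, 2, 16, 8, 8
key_cache = torch.arange(
    num_blocks * num_heads * (head_size // x) * block_size * x, dtype=torch.float)
key_cache_5d = key_cache.view(num_blocks, num_heads, head_size // x, block_size, x)

block_idx, head_idx, head_offset, block_offset = 1, 1, 11, 3
x_idx, x_offset = head_offset // x, head_offset % x

# Flattened offset, computed exactly as in the kernels.
src_key_idx = (block_idx * num_heads * (head_size // x) * block_size * x
               + head_idx * (head_size // x) * block_size * x
               + x_idx * block_size * x
               + block_offset * x
               + x_offset)

assert key_cache[src_key_idx] == key_cache_5d[block_idx, head_idx, x_idx, block_offset, x_offset]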
@@ -215,3 +333,42 @@ void reshape_and_cache(
        x);
    });
}

void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}

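The launcher uses one thread block per token with at most 512 threads, and the optimized kernel has each thread handle unroll_factor strided elements per iteration (index = i + j * unrolled_dim) rather than four consecutive ones, presumably so that neighbouring threads keep touching neighbouring cache elements. A sketch of the resulting work partition for the 13B benchmark configuration:

# Work partition of gather_cached_kv_kernel_optimized (a sketch; 13B config).
num_heads, head_size = 40, 128
dim = num_heads * head_size             # 5120 elements per token
unroll_factor = 4
unrolled_dim = dim // unroll_factor     # 1280
block_dim = min(dim, 512)               # threads per block, as in the launcher

def elements_for(thread_idx):
    # Elements of one token handled by a single thread.
    return [i + j * unrolled_dim
            for i in range(thread_idx, unrolled_dim, block_dim)
            for j in range(unroll_factor)]

print(len(elements_for(0)))   # 12 elements per thread (3 iterations x 4)
print(elements_for(0)[:4])    # [0, 1280, 2560, 3840]
print(elements_for(1)[:4])    # [1, 1281, 2561, 3841]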
@@ -99,6 +99,47 @@ def test_reshape_and_cache(
    assert torch.allclose(value_cache, cloned_value_cache)


def test_gather_cached_kv(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')

    qkv = torch.randn(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)

    qkv_clone = qkv.clone()
    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)

    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=value_cache_shape, dtype=dtype, device='cuda')

    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)

    # Reference implementation.
    for i in range(num_tokens):
        reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
        cloned_value[i] = value_cache[block_idx, :, :, block_offset]

    assert torch.allclose(key, cloned_key)
    assert torch.allclose(value, cloned_value)

@torch.inference_mode()
def test_cache() -> None:
    test_copy_blocks(
@@ -107,6 +148,9 @@ def test_cache() -> None:
    test_reshape_and_cache(
        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
        dtype=torch.half)
    test_gather_cached_kv(
        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
        dtype=torch.half)


if __name__ == '__main__':
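The per-token reference loop in test_gather_cached_kv can also be expressed without the Python loop. An equivalent vectorized check (a sketch, reusing the tensors constructed inside the test; ref_key and ref_value are names introduced here):

# Vectorized equivalent of the per-token reference loop (a sketch).
block_idx = torch.div(slot_mapping, block_size, rounding_mode='floor').long()
block_offset = (slot_mapping % block_size).long()

ref_key = key_cache[block_idx, :, :, block_offset, :].reshape(
    num_tokens, num_heads, head_size)
ref_value = value_cache[block_idx, :, :, block_offset]

assert torch.allclose(key, ref_key)
assert torch.allclose(value, ref_value)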