Implement single_query_cached_kv_attention kernel (#3)

Woosuk Kwon authored 2023-03-01 15:02:19 -08:00, committed by GitHub
commit 0deacbce6e (parent cbf8779afa)
12 changed files with 2140 additions and 60 deletions

@@ -15,7 +15,9 @@ class BlockManager:
block_size: int,
num_blocks: int,
) -> None:
assert block_size in [8, 16, 32]
if block_size not in [8, 16]:
raise ValueError(f'Unsupported block size: {block_size}. '
'The block size must be either 8 or 16.')
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
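Only block sizes 8 and 16 are accepted here because the new CUDA kernel (csrc/attention_kernels.cu below) is dispatched only for those sizes and sizes its per-warp thread groups as WARP_SIZE / BLOCK_SIZE. A minimal Python sketch of that relationship; check_block_size is a hypothetical helper for illustration, not code from this commit:

WARP_SIZE = 32

def check_block_size(block_size: int) -> int:
    # Mirrors the validation above and the kernel's constexpr
    # THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE.
    if block_size not in (8, 16):
        raise ValueError(f'Unsupported block size: {block_size}. '
                         'The block size must be either 8 or 16.')
    return WARP_SIZE // block_size  # 4 threads per key for block 8, 2 for block 16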

@@ -3,7 +3,8 @@ from typing import List, Optional
import torch
import torch.nn as nn
from cacheflow import ops
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata
@@ -11,7 +12,7 @@ class OPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = scale
self.scale = float(scale)
def _masked_attention(
self,
@@ -57,38 +58,21 @@ class OPTCacheFlowAttention(nn.Module):
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
block_size = value_cache.shape[2]
block_tables = input_metadata.block_tables
# FIXME(woosuk): Replace the following with a custom op.
for i in range(input_metadata.num_generation_tokens):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(input_metadata.context_lens[i])
keys = []
values = []
for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)
v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
out = self._masked_attention(q, keys, values)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
)
def forward(
self,
@@ -96,7 +80,7 @@ class OPTCacheFlowAttention(nn.Module):
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
@@ -110,7 +94,7 @@ class OPTCacheFlowAttention(nn.Module):
# Reshape the input tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
@@ -125,7 +109,7 @@ class OPTCacheFlowAttention(nn.Module):
cache_event.wait()
# Reshape the keys and values and store them in the cache.
ops.reshape_and_cache(
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, input_metadata.slot_mapping)
if input_metadata.num_generation_tokens > 0:
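The per-token Python loop above is replaced by the custom op, and the value cache switches from [num_blocks, num_heads, block_size, head_size] to [num_blocks, num_heads, head_size, block_size]. A hedged sketch of how one token's value vector is read under the new layout; gather_value is an illustrative helper, not code from this commit:

import torch

def gather_value(value_cache: torch.Tensor, block_number: int, block_offset: int) -> torch.Tensor:
    # New layout [num_blocks, num_heads, head_size, block_size]: the in-block
    # token position is the last dimension, so one token's values for all
    # heads form a strided column.
    return value_cache[block_number, :, :, block_offset]  # [num_heads, head_size]

# Under the old layout the same read was value_cache[block_number, :, block_offset, :].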

@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple
import torch
from cacheflow import ops
from cacheflow import cache_ops
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -57,20 +57,22 @@ class CacheEngine:
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.block_size,
self.head_size,
self.block_size,
)
def allocate_gpu_cache(self) -> List[KVCache]:
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
size=(self.num_gpu_blocks, *key_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
value_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
size=(self.num_gpu_blocks, *value_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
@@ -79,14 +81,16 @@ class CacheEngine:
def allocate_cpu_cache(self) -> List[KVCache]:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=True,
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=True,
)
@@ -104,10 +108,10 @@ class CacheEngine:
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_value_cache, dst_value_cache, src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)
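For readers tracking the shapes CacheEngine now allocates, here is a hedged Python sketch of the per-block key and value shapes; kv_block_shapes is an illustrative helper, with the key shape and the 16-byte packing factor x taken from the kernel comments and the new tests:

import torch

def kv_block_shapes(num_heads: int, head_size: int, block_size: int, dtype: torch.dtype):
    # x packs 16 bytes of the head dimension together for coalesced key reads.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    value_block_shape = (num_heads, head_size, block_size)  # new value layout
    return key_block_shape, value_block_shape

# Provided x divides head_size, both blocks hold num_heads * head_size * block_size
# elements, so the layout change does not change cache memory usage.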

@@ -118,7 +118,7 @@ class Worker:
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=int, device=self.device)
padded_block_tables, dtype=torch.int, device=self.device)
input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,

csrc/attention.cpp (new file, 19 lines)

@@ -0,0 +1,19 @@
#include <torch/extension.h>
void single_query_cached_kv_attention(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
}
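A hedged summary of the tensor contract this binding expects, written as a Python check; check_attention_args is an illustrative helper inferred from the kernel comments and tests/kernels/attention.py, not part of the extension:

import torch

def check_attention_args(out, query, key_cache, value_cache, block_tables, context_lens):
    num_seqs, num_heads, head_size = query.shape
    block_size = key_cache.shape[3]
    x = 16 // query.element_size()
    assert out.shape == (num_seqs, num_heads, head_size)
    assert key_cache.shape[1:] == (num_heads, head_size // x, block_size, x)
    assert value_cache.shape[1:] == (num_heads, head_size, block_size)
    # The kernel dereferences these through data_ptr<int>(), i.e. 32-bit ints.
    assert block_tables.dtype == torch.int32 and context_lens.dtype == torch.int32
    assert all(t.is_cuda for t in (out, query, key_cache, value_cache, block_tables, context_lens))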

csrc/attention_kernels.cu (new file, 400 lines)

@@ -0,0 +1,400 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_utils.h"
#include "cuda_primitives.h"
#include <algorithm>
#define WARP_SIZE 32
namespace cacheflow {
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in the group
// has the 0th, 4th, 8th, ... vectors of the query, and the second thread has the
// 1st, 5th, 9th, ... vectors of the query, and so on.
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 logits and accumulation.
float *logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in the group
// has the 0th, 4th, 8th, ... vectors of the key, and the second thread has the
// 1st, 5th, 9th, ... vectors of the key, and so on.
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename FloatVec<V_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
accs[i] += dot(logits_vec, cast_to_float(v_vec));
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
convert_from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
} // namespace cacheflow
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
int BLOCK_SIZE,
int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs);
dim3 block(NUM_THREADS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
case 32:
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
break;
case 64:
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
break;
case 80:
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break;
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
case 160:
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
break;
case 192:
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
assert(false);
break;
}
}
void single_query_cached_kv_attention(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len) {
// TODO(woosuk): Support BF16.
if (query.element_size() == 2) {
// Half.
if (block_size == 8) {
single_query_cached_kv_attention_launcher<uint16_t, 8>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else if (block_size == 16) {
single_query_cached_kv_attention_launcher<uint16_t, 16>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else {
assert(false);
}
} else if (query.element_size() == 4) {
// Float.
if (block_size == 8) {
single_query_cached_kv_attention_launcher<float, 8>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else if (block_size == 16) {
single_query_cached_kv_attention_launcher<float, 16>(
out,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
max_context_len);
} else {
assert(false);
}
} else {
assert(false);
}
}
#undef WARP_SIZE
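Two pieces of the logic above that are easy to miss are the launcher's shared-memory sizing and the kernel's softmax normalization. A hedged Python sketch of both; shared_mem_bytes and normalize_logits are illustrative names mirroring the kernel, not part of this commit:

import math

def shared_mem_bytes(max_context_len: int, block_size: int, head_size: int,
                     num_threads: int = 128, warp_size: int = 32) -> int:
    num_warps = num_threads // warp_size
    padded_max_context_len = ((max_context_len + block_size - 1) // block_size) * block_size
    logits_size = padded_max_context_len * 4          # FP32 logits, one per token slot
    outputs_size = (num_warps // 2) * head_size * 4   # FP32 partial outputs for the warp reduction
    return max(logits_size, outputs_size)             # the same shared buffer is reused

def normalize_logits(logits):
    # Numerically stable softmax, matching the kernel: subtract the block-wide
    # max, exponentiate, and divide by the sum (guarded by 1e-6).
    qk_max = max(logits)
    exps = [math.exp(x - qk_max) for x in logits]
    inv_sum = 1.0 / (sum(exps) + 1e-6)
    return [v * inv_sum for v in exps]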

csrc/attention_utils.h (new file, 204 lines)

@@ -0,0 +1,204 @@
#pragma once
#include "cuda_primitives.h"
#include <float.h>
#include <type_traits>
#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_OUT
namespace cacheflow {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
template<typename T>
struct FloatVec {};
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
using K_vec_acum = typename FloatVec<K_vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
template<typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
{
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
float4 c;
float zero = 0.f;
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using K_vec_acum = typename FloatVec<uint32_t>::Type;
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Qk_dot<uint16_t, 4> {
template<int N>
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return qk_hmma_dot_(q, k);
#else
return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
} // namespace cacheflow
#undef MMHA_USE_FP32_ACUM_FOR_FMA
#undef MMHA_USE_FP32_ACUM_FOR_OUT
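Both qk_dot_ and block_sum above rely on the same xor-shuffle butterfly reduction. A hedged Python simulation of that pattern over one warp's lanes; xor_butterfly_sum is illustrative only:

def xor_butterfly_sum(lane_values):
    # lane_values stands in for the per-lane partial sums of one warp;
    # its length must be a power of two (a warp has 32 lanes).
    n = len(lane_values)
    vals = list(lane_values)
    mask = n // 2
    while mask >= 1:
        # Each lane adds the value held by its xor-partner, mirroring
        # __shfl_xor_sync(uint32_t(-1), sum, mask).
        vals = [vals[i] + vals[i ^ mask] for i in range(n)]
        mask //= 2
    return vals[0]  # every lane ends up with the full sum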

@@ -48,7 +48,7 @@ __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, block_size, head_size]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int num_heads,
const int head_size,
@@ -73,10 +73,10 @@ __global__ void reshape_and_cache_kernel(
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int tgt_value_idx = block_idx * num_heads * block_size * head_size
+ head_idx * block_size * head_size
+ block_offset * head_size
+ head_offset;
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
}
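The new tgt_value_idx is simply the row-major flat offset of value_cache[block_idx, head_idx, head_offset, block_offset] under the [num_blocks, num_heads, head_size, block_size] layout. A hedged Python restatement (the function name is illustrative):

def tgt_value_idx(block_idx, head_idx, head_offset, block_offset,
                  num_heads, head_size, block_size):
    # Row-major strides of [num_blocks, num_heads, head_size, block_size].
    return (block_idx * num_heads * head_size * block_size
            + head_idx * head_size * block_size
            + head_offset * block_size
            + block_offset)

# Equivalently (illustrative): np.ravel_multi_index(
#     (block_idx, head_idx, head_offset, block_offset),
#     (num_blocks, num_heads, head_size, block_size))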

csrc/cuda_primitives.h (new file, 1318 lines; diff suppressed because it is too large)

@@ -9,15 +9,22 @@ ext_modules = []
# Cache operations.
cache_extension = cpp_extension.CUDAExtension(
name='cacheflow.ops',
name='cacheflow.cache_ops',
sources=['csrc/cache.cpp', 'csrc/cache_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(cache_extension)
# Attention kernels.
attention_extension = cpp_extension.CUDAExtension(
name='cacheflow.attention_ops',
sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(attention_extension)
setuptools.setup(
name='cacheflow',
requires_python='>=3.9',
ext_modules=ext_modules,
cmdclass={'build_ext': cpp_extension.BuildExtension},
)
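A hedged usage note: once the extensions are built (for example with pip install -e . or python setup.py build_ext --inplace, both standard setuptools workflows rather than commands from this commit), the two modules import under the names declared above:

from cacheflow import attention_ops, cache_ops  # noqa: F401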

tests/kernels/attention.py (new file, 142 lines)

@@ -0,0 +1,142 @@
import random
from typing import Optional
import torch
from cacheflow import attention_ops
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = query * scale
attn = torch.einsum('qhd,khd->hqk', query, key)
if attn_mask is not None:
attn = attn + attn_mask
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, value)
return out
def ref_single_query_cached_kv_attention(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
) -> None:
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
num_input_tokens = query.shape[0]
for i in range(num_input_tokens):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(context_lens[i])
keys = []
values = []
for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
scale = 1.0 / (head_size ** 0.5)
out = ref_masked_attention(q, keys, values, scale)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
def test_single_query_cached_kv_attention(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
query = torch.randn(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_tokens):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty_like(query)
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
)
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
key_cache,
value_cache,
block_tables,
context_lens,
)
# NOTE(woosuk): Due to the difference in the data types the two
# implementations use for attention softmax logits and accumulation,
# there is a small difference in the final outputs.
# We should use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention() -> None:
for dtype in [torch.half, torch.float]:
for block_size in [8, 16]:
for head_size in [64, 80, 96, 128, 256]:
test_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
head_size=head_size,
block_size=block_size,
num_blocks=1024,
dtype=dtype,
)
if __name__ == '__main__':
test_attention()

@@ -2,7 +2,7 @@ import random
import torch
from cacheflow.ops import reshape_and_cache
from cacheflow import cache_ops
def test_reshape_and_cache(
@@ -26,30 +26,30 @@ def test_reshape_and_cache(
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
cloned_key_cache = key_cache.clone()
value_cache_shape = (num_blocks, num_heads, block_size, head_size)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
cloned_value_cache = value_cache.clone()
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
for i in range(num_tokens):
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = slot_mapping[i] // block_size
block_offset = slot_mapping[i] % block_size
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
assert torch.allclose(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache)
@torch.no_grad()
def test_kernels():
@torch.inference_mode()
def test_cache() -> None:
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=2, num_blocks=2,
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=torch.half)
if __name__ == '__main__':
test_kernels()
test_cache()