Implement single_query_cached_kv_attention
kernel (#3)
This commit is contained in:
parent
cbf8779afa
commit
0deacbce6e
@ -15,7 +15,9 @@ class BlockManager:
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert block_size in [8, 16, 32]
|
if block_size not in [8, 16]:
|
||||||
|
raise ValueError(f'Unsupported block size: {block_size}'
|
||||||
|
'The block size must be either 8 or 16.')
|
||||||
self.device = device
|
self.device = device
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
|
@ -3,7 +3,8 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from cacheflow import ops
|
from cacheflow import attention_ops
|
||||||
|
from cacheflow import cache_ops
|
||||||
from cacheflow.models import InputMetadata
|
from cacheflow.models import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +12,7 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, scale: float) -> None:
|
def __init__(self, scale: float) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = scale
|
self.scale = float(scale)
|
||||||
|
|
||||||
def _masked_attention(
|
def _masked_attention(
|
||||||
self,
|
self,
|
||||||
@ -57,38 +58,21 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||||
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> None:
|
) -> None:
|
||||||
num_heads = value_cache.shape[1]
|
block_size = value_cache.shape[3]
|
||||||
head_size = value_cache.shape[3]
|
attention_ops.single_query_cached_kv_attention(
|
||||||
block_size = value_cache.shape[2]
|
output,
|
||||||
block_tables = input_metadata.block_tables
|
query,
|
||||||
|
key_cache,
|
||||||
# FIXME(woosuk): Replace the following with a custom op.
|
value_cache,
|
||||||
for i in range(input_metadata.num_generation_tokens):
|
self.scale,
|
||||||
q = query[i].unsqueeze(0)
|
input_metadata.block_tables,
|
||||||
block_table = block_tables[i]
|
input_metadata.context_lens,
|
||||||
context_len = int(input_metadata.context_lens[i])
|
block_size,
|
||||||
|
input_metadata.max_context_len,
|
||||||
keys = []
|
)
|
||||||
values = []
|
|
||||||
for j in range(context_len):
|
|
||||||
block_number = int(block_table[j // block_size])
|
|
||||||
block_offset = j % block_size
|
|
||||||
|
|
||||||
k = key_cache[block_number, :, :, block_offset, :]
|
|
||||||
k = k.reshape(num_heads, head_size)
|
|
||||||
keys.append(k)
|
|
||||||
|
|
||||||
v = value_cache[block_number, :, block_offset, :]
|
|
||||||
values.append(v)
|
|
||||||
keys = torch.stack(keys, dim=0)
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
|
|
||||||
out = self._masked_attention(q, keys, values)
|
|
||||||
out = out.view(num_heads, head_size)
|
|
||||||
output[i].copy_(out, non_blocking=True)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -96,7 +80,7 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||||
@ -110,7 +94,7 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
|
|
||||||
# Reshape the input tensors.
|
# Reshape the input tensors.
|
||||||
num_heads = value_cache.shape[1]
|
num_heads = value_cache.shape[1]
|
||||||
head_size = value_cache.shape[3]
|
head_size = value_cache.shape[2]
|
||||||
query = query.view(-1, num_heads, head_size)
|
query = query.view(-1, num_heads, head_size)
|
||||||
key = key.view(-1, num_heads, head_size)
|
key = key.view(-1, num_heads, head_size)
|
||||||
value = value.view(-1, num_heads, head_size)
|
value = value.view(-1, num_heads, head_size)
|
||||||
@ -125,7 +109,7 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
cache_event.wait()
|
cache_event.wait()
|
||||||
|
|
||||||
# Reshape the keys and values and store them in the cache.
|
# Reshape the keys and values and store them in the cache.
|
||||||
ops.reshape_and_cache(
|
cache_ops.reshape_and_cache(
|
||||||
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
||||||
|
|
||||||
if input_metadata.num_generation_tokens > 0:
|
if input_metadata.num_generation_tokens > 0:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from cacheflow import ops
|
from cacheflow import cache_ops
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -57,20 +57,22 @@ class CacheEngine:
|
|||||||
def get_value_block_shape(self) -> Tuple[int, int, int]:
|
def get_value_block_shape(self) -> Tuple[int, int, int]:
|
||||||
return (
|
return (
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.block_size,
|
|
||||||
self.head_size,
|
self.head_size,
|
||||||
|
self.block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def allocate_gpu_cache(self) -> List[KVCache]:
|
def allocate_gpu_cache(self) -> List[KVCache]:
|
||||||
gpu_cache: List[KVCache] = []
|
gpu_cache: List[KVCache] = []
|
||||||
|
key_block_shape = self.get_key_block_shape()
|
||||||
|
value_block_shape = self.get_value_block_shape()
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_layers):
|
||||||
key_blocks = torch.empty(
|
key_blocks = torch.empty(
|
||||||
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
|
size=(self.num_gpu_blocks, *key_block_shape),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.gpu_id,
|
device=self.gpu_id,
|
||||||
)
|
)
|
||||||
value_blocks = torch.empty(
|
value_blocks = torch.empty(
|
||||||
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
|
size=(self.num_gpu_blocks, *value_block_shape),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.gpu_id,
|
device=self.gpu_id,
|
||||||
)
|
)
|
||||||
@ -79,14 +81,16 @@ class CacheEngine:
|
|||||||
|
|
||||||
def allocate_cpu_cache(self) -> List[KVCache]:
|
def allocate_cpu_cache(self) -> List[KVCache]:
|
||||||
cpu_cache: List[KVCache] = []
|
cpu_cache: List[KVCache] = []
|
||||||
|
key_block_shape = self.get_key_block_shape()
|
||||||
|
value_block_shape = self.get_value_block_shape()
|
||||||
for _ in range(self.num_layers):
|
for _ in range(self.num_layers):
|
||||||
key_blocks = torch.empty(
|
key_blocks = torch.empty(
|
||||||
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
|
size=(self.num_cpu_blocks, *key_block_shape),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
value_blocks = torch.empty(
|
value_blocks = torch.empty(
|
||||||
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
|
size=(self.num_cpu_blocks, *value_block_shape),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
@ -104,10 +108,10 @@ class CacheEngine:
|
|||||||
src_key_cache, src_value_cache = src[i]
|
src_key_cache, src_value_cache = src[i]
|
||||||
dst_key_cache, dst_value_cache = dst[i]
|
dst_key_cache, dst_value_cache = dst[i]
|
||||||
# Copy the key blocks.
|
# Copy the key blocks.
|
||||||
ops.copy_cache_blocks(
|
cache_ops.copy_cache_blocks(
|
||||||
src_key_cache, dst_key_cache, src_to_dst)
|
src_key_cache, dst_key_cache, src_to_dst)
|
||||||
# Copy the value blocks.
|
# Copy the value blocks.
|
||||||
ops.copy_cache_blocks(
|
cache_ops.copy_cache_blocks(
|
||||||
src_value_cache, dst_value_cache, src_to_dst)
|
src_value_cache, dst_value_cache, src_to_dst)
|
||||||
event = self.events[i]
|
event = self.events[i]
|
||||||
event.record(stream=self.cache_stream)
|
event.record(stream=self.cache_stream)
|
||||||
|
@ -118,7 +118,7 @@ class Worker:
|
|||||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||||
for block_table in generation_block_tables]
|
for block_table in generation_block_tables]
|
||||||
block_tables_tensor = torch.tensor(
|
block_tables_tensor = torch.tensor(
|
||||||
padded_block_tables, dtype=int, device=self.device)
|
padded_block_tables, dtype=torch.int, device=self.device)
|
||||||
|
|
||||||
input_metadata = InputMetadata(
|
input_metadata = InputMetadata(
|
||||||
seq_ids=prompt_seq_ids + generation_seq_ids,
|
seq_ids=prompt_seq_ids + generation_seq_ids,
|
||||||
|
19
csrc/attention.cpp
Normal file
19
csrc/attention.cpp
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void single_query_cached_kv_attention(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"single_query_cached_kv_attention",
|
||||||
|
&single_query_cached_kv_attention,
|
||||||
|
"Compute the attention between an input query and the cached key/value tensors");
|
||||||
|
}
|
400
csrc/attention_kernels.cu
Normal file
400
csrc/attention_kernels.cu
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "attention_utils.h"
|
||||||
|
#include "cuda_primitives.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
|
namespace cacheflow {
|
||||||
|
|
||||||
|
// Grid: (num_heads, num_seqs).
|
||||||
|
template<
|
||||||
|
typename scalar_t,
|
||||||
|
int HEAD_SIZE,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS>
|
||||||
|
__global__ void single_query_cached_kv_attention_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
|
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
|
const float scale,
|
||||||
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq) {
|
||||||
|
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
const int thread_idx = threadIdx.x;
|
||||||
|
const int warp_idx = thread_idx / WARP_SIZE;
|
||||||
|
const int lane = thread_idx % WARP_SIZE;
|
||||||
|
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int num_heads = gridDim.x;
|
||||||
|
const int seq_idx = blockIdx.y;
|
||||||
|
|
||||||
|
// A vector type to store a part of a key or a query.
|
||||||
|
// The vector size is configured in such a way that the threads in a thread group
|
||||||
|
// fetch or comput 16 bytes at a time.
|
||||||
|
// For example, if the size of a thread group is 4 and the data type is half,
|
||||||
|
// then the vector size is 16 / (4 * sizeof(half)) == 2.
|
||||||
|
constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
|
||||||
|
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
|
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
|
|
||||||
|
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||||
|
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||||
|
|
||||||
|
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
|
||||||
|
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
|
||||||
|
|
||||||
|
// Load the query to registers.
|
||||||
|
// Each thread in a thread group has a different part of the query.
|
||||||
|
// For example, if the the thread group size is 4, then the first thread in the group
|
||||||
|
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||||
|
// th vectors of the query, and so on.
|
||||||
|
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||||
|
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||||
|
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||||
|
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Memory planning.
|
||||||
|
extern __shared__ char shared_mem[];
|
||||||
|
// NOTE(woosuk): We use FP32 logits and accumulation.
|
||||||
|
float *logits = reinterpret_cast<float*>(shared_mem);
|
||||||
|
// Workspace for reduction.
|
||||||
|
__shared__ float red_smem[2 * NUM_WARPS];
|
||||||
|
|
||||||
|
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||||
|
// Each thread group fetches x elements from the key at a time.
|
||||||
|
constexpr int x = 16 / sizeof(scalar_t);
|
||||||
|
float qk_max = -FLT_MAX;
|
||||||
|
|
||||||
|
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Iterate over the key blocks.
|
||||||
|
// Each warp fetches a block of keys for each iteration.
|
||||||
|
// Each thread group in a warp fetches a key from the block, and computes
|
||||||
|
// dot product with the query.
|
||||||
|
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||||
|
const int physical_block_number = block_table[block_idx];
|
||||||
|
const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
|
||||||
|
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
|
||||||
|
// Load a key to registers.
|
||||||
|
// Each thread in a thread group has a different part of the key.
|
||||||
|
// For example, if the the thread group size is 4, then the first thread in the group
|
||||||
|
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
|
||||||
|
// vectors of the key, and so on.
|
||||||
|
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||||
|
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||||
|
+ head_idx * HEAD_SIZE * BLOCK_SIZE
|
||||||
|
+ physical_block_offset * x;
|
||||||
|
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||||
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
|
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||||
|
k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute dot product.
|
||||||
|
// This includes a reduction across the threads in the same thread group.
|
||||||
|
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
|
||||||
|
const bool mask = token_idx >= context_len;
|
||||||
|
|
||||||
|
if (thread_group_offset == 0) {
|
||||||
|
// Store the partial reductions to shared memory.
|
||||||
|
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||||
|
logits[token_idx] = mask ? 0.f : qk;
|
||||||
|
// Update the max value.
|
||||||
|
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform reduction across the threads in the same warp to get the
|
||||||
|
// max qk value for each "warp" (not across the thread block yet).
|
||||||
|
// The 0-th thread of each thread group already has its max qk value.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||||
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||||
|
}
|
||||||
|
if (lane == 0) {
|
||||||
|
red_smem[warp_idx] = qk_max;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// TODO(woosuk): Refactor this part.
|
||||||
|
// Get the max qk value for the sequence.
|
||||||
|
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||||
|
}
|
||||||
|
// Broadcast the max qk value to all threads.
|
||||||
|
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||||
|
|
||||||
|
// Get the sum of the exp values.
|
||||||
|
float exp_sum = 0.f;
|
||||||
|
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||||
|
float val = __expf(logits[i] - qk_max);
|
||||||
|
logits[i] = val;
|
||||||
|
exp_sum += val;
|
||||||
|
}
|
||||||
|
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
||||||
|
|
||||||
|
// Compute softmax.
|
||||||
|
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||||
|
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||||
|
logits[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Each thread will fetch 16 bytes from the value cache at a time.
|
||||||
|
constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
|
||||||
|
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
|
using L_vec = typename FloatVec<V_vec>::Type;
|
||||||
|
|
||||||
|
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||||
|
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||||
|
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
|
||||||
|
|
||||||
|
float accs[NUM_ROWS_PER_THREAD];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
accs[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||||
|
const int physical_block_number = block_table[block_idx];
|
||||||
|
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||||
|
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
|
||||||
|
|
||||||
|
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||||
|
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE) {
|
||||||
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
accs[i] += dot(logits_vec, cast_to_float(v_vec));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform reduction within each warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
float acc = accs[i];
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||||
|
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||||
|
}
|
||||||
|
accs[i] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(woosuk): A barrier is required because the shared memory space for logits
|
||||||
|
// is reused for the output.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Perform reduction across warps.
|
||||||
|
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||||
|
int mid = i / 2;
|
||||||
|
// Upper warps write to shared memory.
|
||||||
|
if (warp_idx >= mid && warp_idx < i) {
|
||||||
|
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
dst[row_idx] = accs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Lower warps update the output.
|
||||||
|
if (warp_idx < mid) {
|
||||||
|
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
accs[i] += src[row_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the final output.
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
convert_from_float(*(out_ptr + row_idx), accs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cacheflow
|
||||||
|
|
||||||
|
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||||
|
cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||||
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
|
out_ptr, \
|
||||||
|
query_ptr, \
|
||||||
|
key_cache_ptr, \
|
||||||
|
value_cache_ptr, \
|
||||||
|
scale, \
|
||||||
|
block_tables_ptr, \
|
||||||
|
context_lens_ptr, \
|
||||||
|
max_num_blocks_per_seq);
|
||||||
|
|
||||||
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
|
template<
|
||||||
|
typename T,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS = 128>
|
||||||
|
void single_query_cached_kv_attention_launcher(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int max_context_len) {
|
||||||
|
int num_seqs = query.size(0);
|
||||||
|
int num_heads = query.size(1);
|
||||||
|
int head_size = query.size(2);
|
||||||
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
|
||||||
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||||
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||||
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||||
|
int logits_size = padded_max_context_len * sizeof(float);
|
||||||
|
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||||
|
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||||
|
|
||||||
|
dim3 grid(num_heads, num_seqs);
|
||||||
|
dim3 block(NUM_THREADS);
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
switch (head_size) {
|
||||||
|
case 32:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 64:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 160:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 192:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void single_query_cached_kv_attention(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len) {
|
||||||
|
// TODO(woosuk): Support BF16.
|
||||||
|
if (query.element_size() == 2) {
|
||||||
|
// Half.
|
||||||
|
if (block_size == 8) {
|
||||||
|
single_query_cached_kv_attention_launcher<uint16_t, 8>(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
max_context_len);
|
||||||
|
} else if (block_size == 16) {
|
||||||
|
single_query_cached_kv_attention_launcher<uint16_t, 16>(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
max_context_len);
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
} else if (query.element_size() == 4) {
|
||||||
|
// Float.
|
||||||
|
if (block_size == 8) {
|
||||||
|
single_query_cached_kv_attention_launcher<float, 8>(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
max_context_len);
|
||||||
|
} else if (block_size == 16) {
|
||||||
|
single_query_cached_kv_attention_launcher<float, 16>(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
max_context_len);
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef WARP_SIZE
|
204
csrc/attention_utils.h
Normal file
204
csrc/attention_utils.h
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "cuda_primitives.h"
|
||||||
|
|
||||||
|
#include <float.h>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
||||||
|
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
||||||
|
|
||||||
|
namespace cacheflow {
|
||||||
|
|
||||||
|
// A vector type to store Q, K, V elements.
|
||||||
|
template<typename T, int VEC_SIZE>
|
||||||
|
struct Vec {};
|
||||||
|
template<>
|
||||||
|
struct Vec<float, 1> {
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<float, 2> {
|
||||||
|
using Type = float2;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<float, 4> {
|
||||||
|
using Type = float4;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<uint16_t, 1> {
|
||||||
|
using Type = uint16_t;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<uint16_t, 2> {
|
||||||
|
using Type = uint32_t;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<uint16_t, 4> {
|
||||||
|
using Type = uint2;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<uint16_t, 8> {
|
||||||
|
using Type = uint4;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct FloatVec {};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<float> {
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<float2> {
|
||||||
|
using Type = float2;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<float4> {
|
||||||
|
using Type = float4;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<uint16_t> {
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<uint32_t> {
|
||||||
|
using Type = float2;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<uint2> {
|
||||||
|
using Type = Float4_;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<uint4> {
|
||||||
|
using Type = Float8_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int THREADS_PER_KEY, typename K_vec, int N>
|
||||||
|
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
|
||||||
|
{
|
||||||
|
using K_vec_acum = typename FloatVec<K_vec>::Type;
|
||||||
|
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
||||||
|
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 1; ii < N; ++ii) {
|
||||||
|
qk_vec = fma(q[ii], k[ii], qk_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize the reduction across lanes.
|
||||||
|
float qk = sum(qk_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
|
||||||
|
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
||||||
|
}
|
||||||
|
return qk;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<typename T, int THREADS_PER_KEY>
|
||||||
|
struct Qk_dot {
|
||||||
|
template<typename K_vec, int N>
|
||||||
|
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
|
||||||
|
{
|
||||||
|
return qk_dot_<THREADS_PER_KEY>(q, k);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
|
||||||
|
{
|
||||||
|
float4 c;
|
||||||
|
float zero = 0.f;
|
||||||
|
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
|
||||||
|
" {%0, %1, %2, %3}, \n"
|
||||||
|
" {%4, %5}, \n"
|
||||||
|
" {%6}, \n"
|
||||||
|
" {%7, %7, %7, %7}; \n"
|
||||||
|
|
||||||
|
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
|
||||||
|
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<int N>
|
||||||
|
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||||
|
using K_vec_acum = typename FloatVec<uint32_t>::Type;
|
||||||
|
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
|
||||||
|
#pragma unroll
|
||||||
|
for (int ii = 1; ii < N; ++ii) {
|
||||||
|
qk_vec = fma(q[ii], k[ii], qk_vec);
|
||||||
|
}
|
||||||
|
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
||||||
|
uint32_t qk_vec_ = float2_to_half2(qk_vec);
|
||||||
|
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
|
||||||
|
#else
|
||||||
|
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
return 0.f;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct Qk_dot<uint16_t, 4> {
|
||||||
|
template<int N>
|
||||||
|
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
||||||
|
{
|
||||||
|
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
|
||||||
|
return qk_hmma_dot_(q, k);
|
||||||
|
#else
|
||||||
|
return qk_dot_<4>(q, k);
|
||||||
|
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
|
||||||
|
inline __device__ float block_sum(float* red_smem, float sum)
|
||||||
|
{
|
||||||
|
|
||||||
|
// Decompose the thread index into warp / lane.
|
||||||
|
int warp = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
// Compute the sum per warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warp leaders store the data to shared memory.
|
||||||
|
if (lane == 0) {
|
||||||
|
red_smem[warp] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the data is in shared memory.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// The warps compute the final sums.
|
||||||
|
if (lane < WARPS_PER_BLOCK) {
|
||||||
|
sum = red_smem[lane];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parallel reduction inside the warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
|
||||||
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast to other threads.
|
||||||
|
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cacheflow
|
||||||
|
|
||||||
|
#undef MMHA_USE_FP32_ACUM_FOR_FMA
|
||||||
|
#undef MMHA_USE_FP32_ACUM_FOR_OUT
|
@ -48,7 +48,7 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, block_size, head_size]
|
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size,
|
const int head_size,
|
||||||
@ -73,10 +73,10 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
+ x_idx * block_size * x
|
+ x_idx * block_size * x
|
||||||
+ block_offset * x
|
+ block_offset * x
|
||||||
+ x_offset;
|
+ x_offset;
|
||||||
const int tgt_value_idx = block_idx * num_heads * block_size * head_size
|
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||||
+ head_idx * block_size * head_size
|
+ head_idx * head_size * block_size
|
||||||
+ block_offset * head_size
|
+ head_offset * block_size
|
||||||
+ head_offset;
|
+ block_offset;
|
||||||
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
|
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
|
||||||
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
|
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
|
||||||
}
|
}
|
||||||
|
1318
csrc/cuda_primitives.h
Normal file
1318
csrc/cuda_primitives.h
Normal file
File diff suppressed because it is too large
Load Diff
11
setup.py
11
setup.py
@ -9,15 +9,22 @@ ext_modules = []
|
|||||||
|
|
||||||
# Cache operations.
|
# Cache operations.
|
||||||
cache_extension = cpp_extension.CUDAExtension(
|
cache_extension = cpp_extension.CUDAExtension(
|
||||||
name='cacheflow.ops',
|
name='cacheflow.cache_ops',
|
||||||
sources=['csrc/cache.cpp', 'csrc/cache_kernels.cu'],
|
sources=['csrc/cache.cpp', 'csrc/cache_kernels.cu'],
|
||||||
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||||
)
|
)
|
||||||
ext_modules.append(cache_extension)
|
ext_modules.append(cache_extension)
|
||||||
|
|
||||||
|
# Attention kernels.
|
||||||
|
attention_extension = cpp_extension.CUDAExtension(
|
||||||
|
name='cacheflow.attention_ops',
|
||||||
|
sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
|
||||||
|
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||||
|
)
|
||||||
|
ext_modules.append(attention_extension)
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='cacheflow',
|
name='cacheflow',
|
||||||
requires_python='>=3.9',
|
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
cmdclass={'build_ext': cpp_extension.BuildExtension},
|
cmdclass={'build_ext': cpp_extension.BuildExtension},
|
||||||
)
|
)
|
||||||
|
142
tests/kernels/attention.py
Normal file
142
tests/kernels/attention.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from cacheflow import attention_ops
|
||||||
|
|
||||||
|
|
||||||
|
def ref_masked_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
query = query * scale
|
||||||
|
attn = torch.einsum('qhd,khd->hqk', query, key)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn = attn + attn_mask
|
||||||
|
attn = torch.softmax(attn, dim=-1)
|
||||||
|
out = torch.einsum('hqk,khd->qhd', attn, value)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def ref_single_query_cached_kv_attention(
|
||||||
|
output: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
context_lens: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
num_heads = value_cache.shape[1]
|
||||||
|
head_size = value_cache.shape[2]
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
|
||||||
|
num_input_tokens = query.shape[0]
|
||||||
|
for i in range(num_input_tokens):
|
||||||
|
q = query[i].unsqueeze(0)
|
||||||
|
block_table = block_tables[i]
|
||||||
|
context_len = int(context_lens[i])
|
||||||
|
|
||||||
|
keys = []
|
||||||
|
values = []
|
||||||
|
for j in range(context_len):
|
||||||
|
block_number = int(block_table[j // block_size])
|
||||||
|
block_offset = j % block_size
|
||||||
|
|
||||||
|
k = key_cache[block_number, :, :, block_offset, :]
|
||||||
|
k = k.reshape(num_heads, head_size)
|
||||||
|
keys.append(k)
|
||||||
|
|
||||||
|
v = value_cache[block_number, :, :, block_offset]
|
||||||
|
values.append(v)
|
||||||
|
keys = torch.stack(keys, dim=0)
|
||||||
|
values = torch.stack(values, dim=0)
|
||||||
|
|
||||||
|
scale = 1.0 / (head_size ** 0.5)
|
||||||
|
out = ref_masked_attention(q, keys, values, scale)
|
||||||
|
out = out.view(num_heads, head_size)
|
||||||
|
output[i].copy_(out, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_query_cached_kv_attention(
|
||||||
|
num_tokens: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
block_size: int,
|
||||||
|
num_blocks: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> None:
|
||||||
|
query = torch.randn(
|
||||||
|
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
|
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||||
|
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||||
|
key_cache = torch.randn(
|
||||||
|
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||||
|
value_block_shape = (num_heads, head_size, block_size)
|
||||||
|
value_cache = torch.randn(
|
||||||
|
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||||
|
|
||||||
|
context_lens = [random.randint(1, 4096) for _ in range(num_tokens)]
|
||||||
|
max_context_len = max(context_lens)
|
||||||
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
|
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||||
|
block_tables = []
|
||||||
|
for _ in range(num_tokens):
|
||||||
|
block_table = [
|
||||||
|
random.randint(0, num_blocks - 1)
|
||||||
|
for _ in range(max_num_blocks_per_seq)
|
||||||
|
]
|
||||||
|
block_tables.append(block_table)
|
||||||
|
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
|
scale = float(1.0 / (head_size ** 0.5))
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
attention_ops.single_query_cached_kv_attention(
|
||||||
|
output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_output = torch.empty_like(query)
|
||||||
|
ref_single_query_cached_kv_attention(
|
||||||
|
ref_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
)
|
||||||
|
# NOTE(woosuk): Due to the difference in the data types the two
|
||||||
|
# implementations use for attention softmax logits and accumulation,
|
||||||
|
# there is a small difference in the final outputs.
|
||||||
|
# We should use a relaxed tolerance for the test.
|
||||||
|
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_attention() -> None:
|
||||||
|
for dtype in [torch.half, torch.float]:
|
||||||
|
for block_size in [8, 16]:
|
||||||
|
for head_size in [64, 80, 96, 128, 256]:
|
||||||
|
test_single_query_cached_kv_attention(
|
||||||
|
num_tokens=37,
|
||||||
|
num_heads=3,
|
||||||
|
head_size=head_size,
|
||||||
|
block_size=block_size,
|
||||||
|
num_blocks=1024,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_attention()
|
@ -2,7 +2,7 @@ import random
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from cacheflow.ops import reshape_and_cache
|
from cacheflow import cache_ops
|
||||||
|
|
||||||
|
|
||||||
def test_reshape_and_cache(
|
def test_reshape_and_cache(
|
||||||
@ -26,30 +26,30 @@ def test_reshape_and_cache(
|
|||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||||
cloned_key_cache = key_cache.clone()
|
cloned_key_cache = key_cache.clone()
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, block_size, head_size)
|
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||||
value_cache = torch.randn(
|
value_cache = torch.randn(
|
||||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
|
|
||||||
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||||
block_idx = slot_mapping[i] // block_size
|
block_idx = slot_mapping[i] // block_size
|
||||||
block_offset = slot_mapping[i] % block_size
|
block_offset = slot_mapping[i] % block_size
|
||||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||||
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
|
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||||
|
|
||||||
assert torch.allclose(key_cache, cloned_key_cache)
|
assert torch.allclose(key_cache, cloned_key_cache)
|
||||||
assert torch.allclose(value_cache, cloned_value_cache)
|
assert torch.allclose(value_cache, cloned_value_cache)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.inference_mode()
|
||||||
def test_kernels():
|
def test_cache() -> None:
|
||||||
test_reshape_and_cache(
|
test_reshape_and_cache(
|
||||||
num_tokens=3, num_heads=2, head_size=16, block_size=2, num_blocks=2,
|
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
|
||||||
dtype=torch.half)
|
dtype=torch.half)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_kernels()
|
test_cache()
|
Loading…
x
Reference in New Issue
Block a user