[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799)
Co-authored-by: beagleski <yunanzhang@microsoft.com>
Co-authored-by: bapatra <bapatra@microsoft.com>
Co-authored-by: Barun Patra <codedecde@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent e64fde4b01
commit 8e192ff967
@@ -85,6 +85,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]

@@ -104,7 +105,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;

@@ -202,11 +205,55 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
     // NOTE(woosuk): The block number is stored in int32. However, we cast it to
     // int64 because int32 can lead to overflow when this variable is multiplied
     // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE(linxihui): assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
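Aside: the hunk above is where the kernel decides, per KV block, whether to skip work. A block is kept either because it is local (within the last `blocksparse_local_blocks` sparse blocks before the query's block) or because it is a "remote" stride block whose index, shifted by a per-head offset, lands on a multiple of `blocksparse_vert_stride`. Below is a minimal Python sketch of that selection rule for the `blocksparse_head_sliding_step >= 0` case (sliding over q heads); the helper name and the example numbers are illustrative only and are not part of this commit.

# Illustrative sketch, not part of the diff. Mirrors the kernel's predicate:
#   is_remote: (k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0
#   is_local:  k_bs_block_id > q_bs_block_id - blocksparse_local_blocks
def attended_kv_blocks(seq_len: int, head_idx: int, num_heads: int,
                       blocksparse_block_size: int,
                       blocksparse_local_blocks: int,
                       blocksparse_vert_stride: int,
                       blocksparse_head_sliding_step: int,
                       tp_rank: int = 0):
    # num_heads here is the per-TP-rank query head count, as in the kernel.
    q_bs_block_id = (seq_len - 1) // blocksparse_block_size
    bs_block_offset = ((tp_rank * num_heads + head_idx) *
                       blocksparse_head_sliding_step + 1)
    attended = []
    for k_bs_block_id in range(q_bs_block_id + 1):
        is_remote = (k_bs_block_id +
                     bs_block_offset) % blocksparse_vert_stride == 0
        is_local = k_bs_block_id > q_bs_block_id - blocksparse_local_blocks
        if is_remote or is_local:
            attended.append(k_bs_block_id)
    return attended

# Example (hypothetical numbers): a 4096-token prompt with 64-token sparse
# blocks, 16 local blocks, vertical stride 8, head sliding step 2. Head 0
# keeps blocks 7, 15, 23, 31, 39, 47 (remote) plus 48..63 (local).
print(attended_kv_blocks(4096, 0, 32, 64, 16, 8, 2))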
@@ -335,6 +382,15 @@ __device__ void paged_attention_kernel(
     // NOTE(woosuk): The block number is stored in int32. However, we cast it to
     // int64 because int32 can lead to overflow when this variable is multiplied
     // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -441,8 +497,8 @@ __device__ void paged_attention_kernel(

 // Grid: (num_heads, num_seqs, 1).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
-          int NUM_THREADS,
-          vllm::Fp8KVCacheDataType KV_DTYPE>
+          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
 __global__ void paged_attention_v1_kernel(
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]

@@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale);
+      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }

 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE, PARTITION_SIZE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale);
+      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }

 // Grid: (num_heads, num_seqs).
@@ -607,25 +672,32 @@ __global__ void paged_attention_v2_reduce_kernel(

 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-      ((void*)vllm::paged_attention_v1_kernel< \
-          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
+      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
+                                              BLOCK_SIZE, NUM_THREADS, \
+                                              KV_DTYPE, IS_BLOCK_SPARSE>), \
       shared_mem_size); \
   vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
-                                  NUM_THREADS, KV_DTYPE> \
+                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
       <<<grid, block, shared_mem_size, stream>>>( \
           out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
           scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
           alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
-          kv_scale);
+          kv_scale, tp_rank, blocksparse_local_blocks, \
+          blocksparse_vert_stride, blocksparse_block_size, \
+          blocksparse_head_sliding_step);

 // TODO(woosuk): Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
+          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -691,23 +763,36 @@ void paged_attention_v1_launcher(
   }
 }

-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
+                              IS_BLOCK_SPARSE>( \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale);
+      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
+      blocksparse_local_blocks, blocksparse_vert_stride, \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) { \
+    case true: \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+      break; \
+    case false: \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
+      break; \
+  }

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
   switch (block_size) { \
     case 8: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
       break; \
     case 16: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
       break; \
     case 32: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -727,18 +812,26 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,  // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale){
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
+
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
-                             CALL_V1_LAUNCHER_BLOCK_SIZE)}
+                             CALL_V1_LAUNCHER_BLOCK_SIZE)
+}

 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
   vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
-                                  NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
+                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
+                                  PARTITION_SIZE> \
       <<<grid, block, shared_mem_size, stream>>>( \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
-          kv_block_stride, kv_head_stride, kv_scale); \
+          kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
+          blocksparse_local_blocks, blocksparse_vert_stride, \
+          blocksparse_block_size, blocksparse_head_sliding_step); \
   vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
                                          PARTITION_SIZE> \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
@@ -746,14 +839,17 @@ void paged_attention_v1(
           max_num_partitions);

 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
-          int PARTITION_SIZE = 512>
+          vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -824,24 +920,36 @@ void paged_attention_v2_launcher(
   }
 }

-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
+                              IS_BLOCK_SPARSE>( \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      kv_scale);
+      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) { \
+    case true: \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
+      break; \
+    case false: \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
+      break; \
+  }

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
   switch (block_size) { \
     case 8: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
       break; \
     case 16: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
       break; \
     case 32: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -865,7 +973,10 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,  // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
 }
@@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
 }
 }  // namespace

-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)

@@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
 }
 }  // namespace

-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
csrc/ops.h
@@ -2,23 +2,24 @@

 #include <torch/extension.h>

-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);

-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);

 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               float epsilon);
@@ -123,6 +123,10 @@ Alongside each architecture, we include some popular models that use it.
     - Phi-3
     - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
     -
+  * - :code:`Phi3SmallForCausalLM`
+    - Phi-3-Small
+    - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
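For reference, a minimal usage sketch for the newly listed architecture through the standard `vllm.LLM` entry point. This is illustrative and not part of the diff; the prompt and sampling settings are arbitrary, and `trust_remote_code=True` is assumed to be needed for the model's custom tokenizer.

from vllm import LLM, SamplingParams

# Hypothetical quick check of the new Phi-3-Small support; adjust the model id
# (one of those listed above) and the sampling settings as needed.
llm = LLM(model="microsoft/Phi-3-small-8k-instruct", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Explain block-sparse attention in one sentence."], params)
print(outputs[0].outputs[0].text)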
tests/kernels/test_blocksparse_attention.py (new file, 442 lines)
@@ -0,0 +1,442 @@
import random
from typing import List, Optional, Tuple

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip

from .allclose_default import get_default_atol, get_default_rtol

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# MAX_SEQ_LEN = 2771

# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321  # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [3]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing

HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
BLOCKSPARSE_LOCAL_BLOCKS = [16]
BLOCKSPARSE_VERT_STRIDES = [8]

BLOCKSPARSE_BLOCK_SIZES = [64]
BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1]
BLOCKSPARSE_HOMO_HEADS = [True, False]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 1,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    seq_lens = seq_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        seq_len = int(seq_lens[i])

        keys = []
        values = []
        for j in range(seq_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(seq_len).int()
            alibi_bias = (position_ids - seq_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        if blocksparse_vert_stride >= 1:
            bsize = blocksparse_block_size
            hsliding = blocksparse_head_sliding_step
            vert = blocksparse_vert_stride
            locals = blocksparse_local_blocks
            qb = (seq_len - 1) // bsize
            attn_mask = q.new_zeros(
                (num_query_heads, 1, seq_len)).float() - torch.inf
            for h in range(num_query_heads):
                if hsliding >= 0:  # slide with q heads
                    bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
                else:  # slide with kv heads
                    bs_offset = (tp_rank * num_kv_heads +
                                 h // num_queries_per_kv) * (-hsliding) + 1
                for kb in range(qb + 1):
                    kj = kb * bsize
                    if (qb - kb) < locals or \
                            (kb + bs_offset) % vert == 0:
                        attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
            if alibi_bias is not None:
                attn_mask += alibi_bias
        else:
            attn_mask = alibi_bias

        out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)

@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_head_sliding_step",
                         BLOCKSPARSE_HEADS_SLIDINGS)
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_head_sliding_step: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)

    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int)

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size,
                                                kv_cache_dtype, dtype, seed,
                                                device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Using default kv_scale
    kv_scale = 1.0
    tp_rank = 0

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
            tp_rank=tp_rank,
            blocksparse_local_blocks=blocksparse_local_blocks,
            blocksparse_vert_stride=blocksparse_vert_stride,
            blocksparse_block_size=blocksparse_block_size,
            blocksparse_head_sliding_step=blocksparse_head_sliding_step,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    if kv_cache_dtype == "fp8":
        # Convert cache data back to dtype.
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
                           block_size, x)
        dequantized_key_cache = torch.empty(size=key_cache_shape,
                                            dtype=dtype,
                                            device=device)
        ops.convert_fp8(dequantized_key_cache, key_cache)
        key_cache = dequantized_key_cache

        value_cache_shape = value_cache.shape
        dequantized_value_cache = torch.empty(size=value_cache_shape,
                                              dtype=dtype,
                                              device=device)
        ops.convert_fp8(dequantized_value_cache, value_cache)
        value_cache = dequantized_value_cache

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        seq_lens,
        scale,
        alibi_slopes,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    atol = get_default_atol(output) if is_hip() else 1e-3
    rtol = get_default_rtol(output) if is_hip() else 1e-5

    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
    # so we use a relaxed tolerance for the test.
    atol, rtol = 1e-3, 1e-5
    if kv_cache_dtype == "fp8":
        atol, rtol = 1e-2, 1e-5
    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)

def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    blocksparse_local_blocks: int,
    blocksparse_vert_stride: int,
    blocksparse_block_size: int,
    blocksparse_homo_heads: bool,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads

    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    bs_attn_op = LocalStridedBlockSparseAttn(
        num_query_heads,
        max_len,
        local_blocks=blocksparse_local_blocks,
        vert_stride=blocksparse_vert_stride,
        block_size=blocksparse_block_size,
        device=device,
        dtype=dtype,
        homo_head=blocksparse_homo_heads)

    output = bs_attn_op(query,
                        key,
                        value,
                        cu_seq_lens.to(device),
                        sm_scale=scale)

    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2)
@@ -45,11 +45,17 @@ def paged_attention_v1(
     alibi_slopes: Optional[torch.Tensor],
     kv_cache_dtype: str,
     kv_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
 ) -> None:
-    vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
-                                num_kv_heads, scale, block_tables, seq_lens,
-                                block_size, max_seq_len, alibi_slopes,
-                                kv_cache_dtype, kv_scale)
+    vllm_ops.paged_attention_v1(
+        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
+        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
+        blocksparse_block_size, blocksparse_head_sliding_step)


 def paged_attention_v2(

@@ -69,12 +75,18 @@ def paged_attention_v2(
     alibi_slopes: Optional[torch.Tensor],
     kv_cache_dtype: str,
     kv_scale: float,
+    tp_rank: int = 0,
+    blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
 ) -> None:
-    vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
-                                key_cache, value_cache, num_kv_heads, scale,
-                                block_tables, seq_lens, block_size,
-                                max_seq_len, alibi_slopes, kv_cache_dtype,
-                                kv_scale)
+    vllm_ops.paged_attention_v2(
+        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
+        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
+        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
+        blocksparse_local_blocks, blocksparse_vert_stride,
+        blocksparse_block_size, blocksparse_head_sliding_step)


 # pos encoding ops
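Note that the new wrapper arguments all carry defaults, so existing call sites keep their dense behavior: the C++ entry points shown earlier instantiate the block-sparse kernel only when `blocksparse_vert_stride > 1`. A tiny illustrative sketch of that dispatch predicate (the helper name is hypothetical, not part of the diff):

# Mirrors `const bool is_block_sparse = (blocksparse_vert_stride > 1);` in the
# C++ paged_attention_v1/v2 entry points; the Python defaults above therefore
# leave unchanged callers on the dense paged-attention path.
def uses_blocksparse_kernel(blocksparse_vert_stride: int = 0) -> bool:
    return blocksparse_vert_stride > 1

assert uses_blocksparse_kernel() is False   # defaults: dense kernel
assert uses_blocksparse_kernel(8) is True   # e.g. vert_stride=8: block-sparse kernel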
@@ -111,6 +111,7 @@ class AttentionImpl(ABC, Generic[T]):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         raise NotImplementedError

vllm/attention/backends/blocksparse_attn.py (new file, 410 lines)
@@ -0,0 +1,410 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)


@dataclass
class BlocksparseParams:
    max_seqlen: int

    # Num q heads per tensor-parallel rank/partition
    num_heads: int  # per TP partition
    # Num kv heads per tensor-parallel rank/partition
    num_kv_heads: int

    # block size used for blocksparse attention.
    # This is the block_size used in `local_blocks`, `vert_stride`.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # local attended tokens / `sparse_block_size`
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks.
    # Controlling the sparsity
    vert_stride: int
    """
    If to use the same vertical stride offset for all heads,
    i.e., attend to the same block of tokens on all heads.
    By default, it is False, i.e., attention on the non-local
    blocks depends on the `head_idx`, that is on
    blocks satisfying
    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
    `block_idx = position_id // sparse_block_size`.
    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    for more detail.
    """
    homo_head: bool = False

    # If within a group, the kv offsets that each q attends is the same or no.
    homo_head_group: bool = False

    # Decided by homo_head and homo_head group
    head_sliding_step: int = field(init=False)

    # range of q heads to for a TP rank
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1
        assert self.num_heads % self.num_kv_heads == 0

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        total_heads = tp_size * self.num_heads
        total_kv_heads = tp_size * self.num_kv_heads

        if self.homo_head:
            self.head_sliding_step = 0
        elif self.homo_head_group:
            head_sliding_step = get_head_sliding_step(total_kv_heads,
                                                      self.vert_stride)
            # negative indicates sliding along kv heads, i.e., homo q group
            self.head_sliding_step = -head_sliding_step
        else:
            self.head_sliding_step = get_head_sliding_step(
                total_heads, self.vert_stride)

        self.active_head_range = (
            tp_rank * self.num_heads,
            (tp_rank + 1) * self.num_heads,
        )

class BlocksparseFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
        return BlocksparseFlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

|
@dataclass
|
||||||
|
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
||||||
|
"""A copy of Metadata for FlashAttentionBackend,
|
||||||
|
to avoid having to install flash_attn.
|
||||||
|
|
||||||
|
NOTE: Any python object stored here is not updated when it is
|
||||||
|
cuda-graph replayed. If you have values that need to be changed
|
||||||
|
dynamically, it should be stored in tensor. The tensor has to be
|
||||||
|
updated from `CUDAGraphRunner.forward` API.
|
||||||
|
"""
|
||||||
|
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||||
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
|
seq_lens: Optional[List[int]]
|
||||||
|
# seq_lens stored as a tensor.
|
||||||
|
seq_lens_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||||
|
# |---------- N-1 iteration --------|
|
||||||
|
# |---------------- N iteration ---------------------|
|
||||||
|
# |- tokenA -|......................|-- newTokens ---|
|
||||||
|
# |---------- context_len ----------|
|
||||||
|
# |-------------------- seq_len ----------------------|
|
||||||
|
# |-- query_len ---|
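# Worked example (illustrative numbers, not from the source): if a sequence
# already had 12 tokens computed in earlier iterations and 4 new tokens are
# scheduled in this iteration, then context_len = 12, query_len = 4, and
# seq_len = 16.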
|
||||||
|
|
||||||
|
# Maximum query length in the batch. None for decoding.
|
||||||
|
max_query_len: Optional[int]
|
||||||
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
|
# requests only.
|
||||||
|
max_prefill_seq_len: int
|
||||||
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
|
# requests only.
|
||||||
|
max_decode_seq_len: int
|
||||||
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
|
# is [4, 6], it is [0, 4, 10].
|
||||||
|
query_start_loc: Optional[torch.Tensor]
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
seq_start_loc: Optional[torch.Tensor]
|
||||||
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||||
|
# so far).
|
||||||
|
context_lens_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# (batch_size, max_blocks_per_seq).
|
||||||
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||||
|
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||||
|
# in the kv cache. Each block can contain up to block_size tokens.
|
||||||
|
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||||
|
# captured.
|
||||||
|
block_tables: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# Whether or not cuda graph is enabled.
|
||||||
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
|
use_cuda_graph: bool
|
||||||
|
|
||||||
|
_cached_prefill_metadata: Optional[
|
||||||
|
"BlocksparseFlashAttentionMetadata"] = None
|
||||||
|
_cached_decode_metadata: Optional[
|
||||||
|
"BlocksparseFlashAttentionMetadata"] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(
|
||||||
|
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||||
|
if self.num_prefills == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_prefill_metadata is not None:
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
assert self.seq_lens is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
assert self.query_start_loc is not None
|
||||||
|
assert self.context_lens_tensor is not None
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_start_loc is not None
|
||||||
|
|
||||||
|
self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
|
||||||
|
num_prefills=self.num_prefills,
|
||||||
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||||
|
seq_lens=self.seq_lens[:self.num_prefills],
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
|
max_decode_seq_len=0,
|
||||||
|
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||||
|
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||||
|
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||||
|
block_tables=self.block_tables[:self.num_prefills],
|
||||||
|
use_cuda_graph=False,
|
||||||
|
)
|
||||||
|
return self._cached_prefill_metadata
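# NOTE: the slicing above assumes the flattened batch is ordered with all
# prefill tokens/sequences first, followed by all decode ones, so
# [:self.num_prefills] / [:self.num_prefill_tokens] select exactly the
# prefill portion (and the decode property below takes the complement).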
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_decode_metadata is not None:
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
|
||||||
|
self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
|
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||||
|
seq_lens=None,
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||||
|
max_query_len=None,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=self.block_tables[self.num_prefills:],
|
||||||
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
|
)
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
|
class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||||
|
"""
|
||||||
|
If the input tensors contain prompt tokens, the layout is as follows:
|
||||||
|
|<--------------- num_prompt_tokens -------------->|
|
||||||
|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||||
|
|
||||||
|
Otherwise, the layout is as follows:
|
||||||
|
|<------------------ num_generation_tokens (M) ----------------->|
|
||||||
|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||||
|
|
||||||
|
Generation tokens can contain padding when cuda-graph is used.
|
||||||
|
Currently, prompt tokens don't contain any padding.
|
||||||
|
|
||||||
|
The prompts might have different lengths, while the generation tokens
|
||||||
|
always have length 1.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int,
|
||||||
|
alibi_slopes: Optional[List[float]],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
assert blocksparse_params is not None
|
||||||
|
assert alibi_slopes is None, ValueError(
    "ALiBi is not supported for blocksparse flash attention.")
assert sliding_window is None, ValueError(
    "sliding_window is invalid for blocksparse attention.")
|
||||||
|
|
||||||
|
if "num_heads" not in blocksparse_params:
|
||||||
|
blocksparse_params["num_heads"] = num_heads
|
||||||
|
if "num_kv_heads" not in blocksparse_params:
|
||||||
|
blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
|
||||||
|
self.blocksparse_params = BlocksparseParams(**blocksparse_params)
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.alibi_slopes = alibi_slopes
|
||||||
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
|
||||||
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
self.local_blocks = self.blocksparse_params.local_blocks
|
||||||
|
self.vert_stride = self.blocksparse_params.vert_stride
|
||||||
|
self.sparse_block_size = self.blocksparse_params.block_size
|
||||||
|
self.head_sliding_step = self.blocksparse_params.head_sliding_step
|
||||||
|
|
||||||
|
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
    raise ValueError(
        f"Head size {head_size} is not supported by PagedAttention. "
        f"Supported head sizes are: {supported_head_sizes}.")
|
||||||
|
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
total_num_heads = num_heads * self.tp_size
|
||||||
|
self.bs_attn = LocalStridedBlockSparseAttn(
|
||||||
|
total_num_heads,
|
||||||
|
self.blocksparse_params.max_seqlen,
|
||||||
|
self.blocksparse_params.local_blocks,
|
||||||
|
self.blocksparse_params.vert_stride,
|
||||||
|
self.blocksparse_params.block_size,
|
||||||
|
homo_head=self.blocksparse_params.homo_head,
|
||||||
|
active_head_range=self.blocksparse_params.active_head_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||||
|
kv_scale: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
num_tokens, hidden_size = query.shape
|
||||||
|
# Reshape the query, key, and value tensors.
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||||
|
kv_cache, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
|
# not cached. This happens during the initial memory profiling run.
|
||||||
|
|
||||||
|
PagedAttention.write_to_paged_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
kv_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
|
|
||||||
|
# Prompt run.
|
||||||
|
# normal attention
|
||||||
|
# When block_tables are not filled, it means q and k are the
|
||||||
|
# prompt, and they have the same length.
|
||||||
|
|
||||||
|
assert kv_cache is None \
|
||||||
|
or prefill_meta.block_tables is None \
|
||||||
|
or prefill_meta.block_tables.numel() == 0, \
|
||||||
|
"Does not support prefix-enabled attention."
|
||||||
|
|
||||||
|
output = self.bs_attn(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||||
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
|
sm_scale=self.scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
|
# Decoding run.
|
||||||
|
output = PagedAttention.forward_decode(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
decode_meta.block_tables,
|
||||||
|
decode_meta.seq_lens_tensor,
|
||||||
|
self.blocksparse_params.max_seqlen,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.scale,
|
||||||
|
self.alibi_slopes,
|
||||||
|
kv_scale,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
blocksparse_local_blocks=self.local_blocks,
|
||||||
|
blocksparse_vert_stride=self.vert_stride,
|
||||||
|
blocksparse_block_size=self.sparse_block_size,
|
||||||
|
blocksparse_head_sliding_step=self.head_sliding_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.view(num_tokens, hidden_size)
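As a rough usage sketch, the model layer passes a blocksparse_params dict through Attention(...) into this impl; the keys below are inferred from the BlocksparseParams fields above, while the concrete values are hypothetical and not from the source:

# Hypothetical example of the blocksparse_params dict consumed above.
blocksparse_params = dict(
    max_seqlen=8192,      # assumed model max sequence length
    local_blocks=16,      # number of dense local blocks kept per query block
    vert_stride=8,        # keep every 8th block column (per-head shifted)
    block_size=64,        # sparse block granularity
    homo_head=False,      # heads use shifted (heterogeneous) patterns
)
# `num_heads` / `num_kv_heads` are filled in by
# BlocksparseFlashAttentionImpl.__init__ when absent.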
|
@ -1,6 +1,6 @@
|
|||||||
"""Attention layer with FlashAttention."""
|
"""Attention layer with FlashAttention."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
@ -219,7 +219,10 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
assert blocksparse_params is None, ValueError(
|
||||||
|
"FlashAttention does not support block-sparse attention.")
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Attention layer ROCm GPUs."""
|
"""Attention layer ROCm GPUs."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -201,7 +201,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
assert blocksparse_params is None, ValueError(
|
||||||
|
"ROCFlashAttention does not support blocksparse attention.")
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
""" Attention layer with torch scaled_dot_product_attention
|
""" Attention layer with torch scaled_dot_product_attention
|
||||||
and PagedAttention."""
|
and PagedAttention."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
@ -100,7 +100,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
assert blocksparse_params is None, ValueError(
|
||||||
|
"Torch SPDA does not support block-sparse attention.")
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Attention layer with xFormers and PagedAttention."""
|
"""Attention layer with xFormers and PagedAttention."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
@ -212,7 +212,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
alibi_slopes: Optional[List[float]],
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
assert blocksparse_params is None, ValueError(
|
||||||
|
"XFormer does not support block-sparse attention.")
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Attention layer."""
|
"""Attention layer."""
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -33,6 +33,7 @@ class Attention(nn.Module):
|
|||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int] = None,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if cache_config is not None:
|
if cache_config is not None:
|
||||||
@ -69,10 +70,12 @@ class Attention(nn.Module):
|
|||||||
dtype = torch.get_default_dtype()
|
dtype = torch.get_default_dtype()
|
||||||
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
|
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
|
||||||
sliding_window, dtype, kv_cache_dtype,
|
sliding_window, dtype, kv_cache_dtype,
|
||||||
block_size)
|
block_size, blocksparse_params
|
||||||
|
is not None)
|
||||||
impl_cls = attn_backend.get_impl_cls()
|
impl_cls = attn_backend.get_impl_cls()
|
||||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||||
alibi_slopes, sliding_window, kv_cache_dtype)
|
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||||
|
blocksparse_params)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -90,4 +93,5 @@ class Attention(nn.Module):
|
|||||||
s += f", num_heads={self.impl.num_heads}" # type: ignore
|
s += f", num_heads={self.impl.num_heads}" # type: ignore
|
||||||
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
|
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
|
||||||
s += f", scale={self.impl.scale}" # type: ignore
|
s += f", scale={self.impl.scale}" # type: ignore
|
||||||
|
s += f", backend={self.impl.__class__.__name__}"
|
||||||
return s
|
return s
|
||||||
|
@ -0,0 +1,423 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
def blocksparse_flash_attn_varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v, # (#tokens, n_heads, head_size)
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seqlens_q,
|
||||||
|
sm_scale,
|
||||||
|
sparse_layout,
|
||||||
|
*,
|
||||||
|
block_size=64,
|
||||||
|
q_block_size=None,
|
||||||
|
max_seqlen=None):
|
||||||
|
# split q to blocks
|
||||||
|
|
||||||
|
assert isinstance(sparse_layout, (list, tuple))
|
||||||
|
|
||||||
|
_, n_heads, head_size = q.shape
|
||||||
|
batch_size = cu_seqlens_k.size(0) - 1
|
||||||
|
q_block_size = q_block_size or block_size
|
||||||
|
|
||||||
|
assert q.dim() == k.dim() == v.dim() == 3
|
||||||
|
assert q.size(1) % k.size(1) == 0
|
||||||
|
assert q.size(2) == k.size(2)
|
||||||
|
# TODO(linxihui): allow k, v to have different head_size
|
||||||
|
assert k.shape == v.shape
|
||||||
|
assert cu_seqlens_k.dim() == 1
|
||||||
|
|
||||||
|
q_k_ratio = q.size(1) // k.size(1)
|
||||||
|
|
||||||
|
if cu_seqlens_q is None:
|
||||||
|
if q.size(0) == batch_size: # decoding only
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
batch_size + 1,
|
||||||
|
dtype=cu_seqlens_k.dtype,
|
||||||
|
device=cu_seqlens_k.device,
|
||||||
|
)
|
||||||
|
elif q.size(0) == k.size(0):
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
else:
|
||||||
|
raise ValueError("cu_seqlens_q must be specified\
|
||||||
|
if it mix of prefilling and decoding.")
|
||||||
|
else:
|
||||||
|
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
|
||||||
|
|
||||||
|
# switch to cpu to avoid too many kernel launches when iterating over them
|
||||||
|
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
|
||||||
|
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
|
||||||
|
|
||||||
|
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
|
||||||
|
"length of q should either be 1 (decoding) or same as k (prefilling).")
|
||||||
|
|
||||||
|
if max_seqlen:
|
||||||
|
assert k_lens.max() <= max_seqlen
|
||||||
|
|
||||||
|
n_blocks = (q_lens + q_block_size - 1) // q_block_size
|
||||||
|
|
||||||
|
q_batch_ids = torch.tensor(
|
||||||
|
[i for i, n in enumerate(n_blocks) for _ in range(n)],
|
||||||
|
dtype=cu_seqlens_q.dtype,
|
||||||
|
device=cu_seqlens_q.device,
|
||||||
|
)
|
||||||
|
q_start_sids = torch.tensor(
|
||||||
|
[i * q_block_size for n in n_blocks for i in range(n)],
|
||||||
|
dtype=cu_seqlens_q.dtype,
|
||||||
|
device=cu_seqlens_q.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = q.new_empty(q.shape)
|
||||||
|
cu_seqlens_q = cu_seqlens_q.contiguous()
|
||||||
|
cu_seqlens_k = cu_seqlens_k.contiguous()
|
||||||
|
|
||||||
|
layout_crow_indices, layout_col_indices = sparse_layout
|
||||||
|
block_d = triton.next_power_of_2(head_size)
|
||||||
|
|
||||||
|
decoding_only = (q_lens == 1).all().item()
|
||||||
|
grid = (len(q_start_sids), n_heads, 1)
|
||||||
|
|
||||||
|
_fwd_kernel_batch_inference[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
sm_scale,
|
||||||
|
cu_seqlens_q[:-1],
|
||||||
|
cu_seqlens_q[1:],
|
||||||
|
cu_seqlens_k[:-1],
|
||||||
|
cu_seqlens_k[1:],
|
||||||
|
q_batch_ids,
|
||||||
|
q_start_sids,
|
||||||
|
0,
|
||||||
|
*q.stride(),
|
||||||
|
0,
|
||||||
|
*k.stride(),
|
||||||
|
0,
|
||||||
|
*v.stride(),
|
||||||
|
0,
|
||||||
|
*out.stride(),
|
||||||
|
layout_crow_indices,
|
||||||
|
layout_col_indices,
|
||||||
|
*layout_crow_indices.stride(),
|
||||||
|
*layout_col_indices.stride(),
|
||||||
|
q_k_ratio,
|
||||||
|
HAS_BATCH_DIM=False,
|
||||||
|
D_HEAD=head_size,
|
||||||
|
BLOCK_M=q_block_size,
|
||||||
|
BLOCK_N=block_size,
|
||||||
|
BLOCK_D=block_d,
|
||||||
|
BLOCK_M_LOADING=(16 if decoding_only else
|
||||||
|
q_block_size), # smaller for decoding
|
||||||
|
EVEN_D=block_d == head_size,
|
||||||
|
num_warps=1 if decoding_only else 4,
|
||||||
|
num_stages=3)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _fwd_kernel_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
Q,
|
||||||
|
k_block_col_idx,
|
||||||
|
layout_col_ptr,
|
||||||
|
layout_col_stride_h,
|
||||||
|
layout_col_stride_m,
|
||||||
|
k_ptrs,
|
||||||
|
v_ptrs,
|
||||||
|
off_h,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
offs_d,
|
||||||
|
stride_kt,
|
||||||
|
stride_vt,
|
||||||
|
sm_scale,
|
||||||
|
k_seqlen,
|
||||||
|
past_len,
|
||||||
|
LAST_K_BLOCK: tl.constexpr,
|
||||||
|
BLOCK_M_LOADING: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
D_HEAD: tl.constexpr,
|
||||||
|
EVEN_D: tl.constexpr,
|
||||||
|
M_LT_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
|
||||||
|
k_block_col_idx * layout_col_stride_m).to(tl.int32)
|
||||||
|
start_n = k_block_id * BLOCK_N
|
||||||
|
if LAST_K_BLOCK:
|
||||||
|
if EVEN_D:
|
||||||
|
k = tl.load(
|
||||||
|
k_ptrs + start_n * stride_kt,
|
||||||
|
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
k = tl.load(
|
||||||
|
k_ptrs + start_n * stride_kt,
|
||||||
|
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||||
|
(offs_d[:, None] < D_HEAD),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if EVEN_D:
|
||||||
|
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||||
|
else:
|
||||||
|
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||||
|
mask=offs_d[:, None] < D_HEAD)
|
||||||
|
|
||||||
|
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||||
|
qk += tl.dot(q, k)
|
||||||
|
qk *= sm_scale
|
||||||
|
|
||||||
|
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
|
||||||
|
if LAST_K_BLOCK | M_LT_N:
|
||||||
|
qk += tl.where(
|
||||||
|
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
|
||||||
|
0,
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# flash-attn2
|
||||||
|
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||||
|
p = tl.math.exp2(qk - m_ij[:, None])
|
||||||
|
l_ij = tl.sum(p, 1)
|
||||||
|
alpha = tl.math.exp2(m_i - m_ij)
|
||||||
|
acc = acc * alpha[:, None]
|
||||||
|
# update m_i
|
||||||
|
m_i = m_ij
|
||||||
|
l_i = l_i * alpha + l_ij
|
||||||
|
|
||||||
|
p = p.to(Q.dtype.element_ty)
|
||||||
|
# update acc
|
||||||
|
if LAST_K_BLOCK:
|
||||||
|
if EVEN_D:
|
||||||
|
v = tl.load(
|
||||||
|
v_ptrs + start_n * stride_vt,
|
||||||
|
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
v = tl.load(
|
||||||
|
v_ptrs + start_n * stride_vt,
|
||||||
|
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||||
|
(offs_d[None, :] < D_HEAD),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if EVEN_D:
|
||||||
|
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||||
|
else:
|
||||||
|
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||||
|
mask=offs_d[None, :] < D_HEAD)
|
||||||
|
|
||||||
|
acc += tl.dot(p, v)
|
||||||
|
|
||||||
|
return acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics({
|
||||||
|
"M_LT_N":
|
||||||
|
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _fwd_kernel_batch_inference(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
Out,
|
||||||
|
sm_scale,
|
||||||
|
q_batch_starts,
|
||||||
|
q_batch_ends,
|
||||||
|
k_batch_starts,
|
||||||
|
k_batch_ends,
|
||||||
|
q_batch_ids,
|
||||||
|
q_start_sids,
|
||||||
|
stride_qb,
|
||||||
|
stride_qt,
|
||||||
|
stride_qh,
|
||||||
|
stride_qd,
|
||||||
|
stride_kb,
|
||||||
|
stride_kt,
|
||||||
|
stride_kh,
|
||||||
|
stride_kd,
|
||||||
|
stride_vb,
|
||||||
|
stride_vt,
|
||||||
|
stride_vh,
|
||||||
|
stride_vd,
|
||||||
|
stride_ob,
|
||||||
|
stride_ot,
|
||||||
|
stride_oh,
|
||||||
|
stride_od,
|
||||||
|
layout_crow_ptr,
|
||||||
|
layout_col_ptr,
|
||||||
|
layout_crow_stride_h,
|
||||||
|
layout_crow_stride_m,
|
||||||
|
layout_col_stride_h,
|
||||||
|
layout_col_stride_m,
|
||||||
|
q_k_ratio,
|
||||||
|
HAS_BATCH_DIM: tl.constexpr,
|
||||||
|
D_HEAD: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_D: tl.constexpr,
|
||||||
|
BLOCK_M_LOADING: tl.constexpr,
|
||||||
|
EVEN_D: tl.constexpr,
|
||||||
|
M_LT_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
NOTATION:
|
||||||
|
pid: position id
|
||||||
|
sid: storage id
|
||||||
|
sbid: storage block id
|
||||||
|
pbid: position block id
|
||||||
|
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||||
|
|
||||||
|
TODO(linxihui):
|
||||||
|
Optimize grouped-attn
|
||||||
|
"""
|
||||||
|
off_zm = tl.program_id(0)
|
||||||
|
off_h = tl.program_id(1)
|
||||||
|
|
||||||
|
off_h_for_kv = off_h // q_k_ratio
|
||||||
|
|
||||||
|
if HAS_BATCH_DIM:
|
||||||
|
off_z = tl.program_id(2)
|
||||||
|
Q += off_z * stride_qb
|
||||||
|
K += off_z * stride_kb
|
||||||
|
V += off_z * stride_vb
|
||||||
|
Out += off_z * stride_ob
|
||||||
|
start_m = off_zm
|
||||||
|
q_start_sid = start_m * BLOCK_M # always 0 for decoding
|
||||||
|
else:
|
||||||
|
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
|
||||||
|
q_start_sid = tl.load(q_start_sids + off_zm)
|
||||||
|
start_m = q_start_sid // BLOCK_M # q_sbid
|
||||||
|
|
||||||
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
offs_d = tl.arange(0, BLOCK_D)
|
||||||
|
|
||||||
|
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
|
||||||
|
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
|
||||||
|
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
|
||||||
|
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
|
||||||
|
past_len = k_seqlen - q_seqlen
|
||||||
|
|
||||||
|
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||||
|
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||||
|
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||||
|
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||||
|
|
||||||
|
q_pbid = (past_len + q_start_sid) // BLOCK_M
|
||||||
|
|
||||||
|
if EVEN_D:
|
||||||
|
q = tl.load(
|
||||||
|
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||||
|
mask=offs_m[:, None] < q_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q = tl.load(
|
||||||
|
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||||
|
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||||
|
q_pbid * layout_crow_stride_m)
|
||||||
|
|
||||||
|
# TODO(linxihui): load at once, with any Triton version
|
||||||
|
# that supports `tl.split`, e.g., Triton 3.0
|
||||||
|
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||||
|
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||||
|
|
||||||
|
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||||
|
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||||
|
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||||
|
|
||||||
|
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
|
||||||
|
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
|
||||||
|
|
||||||
|
sm_scale *= (
|
||||||
|
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
|
||||||
|
)
|
||||||
|
|
||||||
|
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||||
|
acc, l_i, m_i = _fwd_kernel_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
Q,
|
||||||
|
k_block_col_idx,
|
||||||
|
layout_col_ptr,
|
||||||
|
layout_col_stride_h,
|
||||||
|
layout_col_stride_m,
|
||||||
|
k_ptrs,
|
||||||
|
v_ptrs,
|
||||||
|
off_h,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
offs_d,
|
||||||
|
stride_kt,
|
||||||
|
stride_vt,
|
||||||
|
sm_scale,
|
||||||
|
k_seqlen,
|
||||||
|
past_len,
|
||||||
|
False,
|
||||||
|
BLOCK_M_LOADING,
|
||||||
|
BLOCK_N,
|
||||||
|
D_HEAD,
|
||||||
|
EVEN_D,
|
||||||
|
M_LT_N,
|
||||||
|
)
|
||||||
|
|
||||||
|
acc, l_i, m_i = _fwd_kernel_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
Q,
|
||||||
|
k_block_end - 1,
|
||||||
|
layout_col_ptr,
|
||||||
|
layout_col_stride_h,
|
||||||
|
layout_col_stride_m,
|
||||||
|
k_ptrs,
|
||||||
|
v_ptrs,
|
||||||
|
off_h,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
offs_d,
|
||||||
|
stride_kt,
|
||||||
|
stride_vt,
|
||||||
|
sm_scale,
|
||||||
|
k_seqlen,
|
||||||
|
past_len,
|
||||||
|
True,
|
||||||
|
BLOCK_M_LOADING,
|
||||||
|
BLOCK_N,
|
||||||
|
D_HEAD,
|
||||||
|
EVEN_D,
|
||||||
|
M_LT_N,
|
||||||
|
)
|
||||||
|
|
||||||
|
# flash-attn 2
|
||||||
|
m_i += tl.math.log2(l_i)
|
||||||
|
acc = acc / l_i[:, None]
|
||||||
|
|
||||||
|
# write output
|
||||||
|
if EVEN_D:
|
||||||
|
tl.store(
|
||||||
|
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||||
|
acc,
|
||||||
|
mask=offs_m[:, None] < q_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tl.store(
|
||||||
|
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
|
||||||
|
acc,
|
||||||
|
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||||
|
)
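The kernel above pre-multiplies sm_scale by 1.44269504 (log2 e) and then uses tl.math.exp2/log2 in place of exp/log, since exp2 is cheaper on GPU. A quick sanity check of the identity it relies on (plain torch, illustrative only):

import torch

x = torch.randn(8)
# exp(x) == 2 ** (x * log2(e)), which is why sm_scale is scaled by 1.44269504
# before the exp2-based softmax accumulation and log2 is used at the end.
assert torch.allclose(torch.exp(x), torch.exp2(x * 1.44269504), atol=1e-6)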
|
vllm/attention/ops/blocksparse_attention/interface.py (new file)
@ -0,0 +1,238 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils import is_cpu, is_hip
|
||||||
|
|
||||||
|
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||||
|
get_sparse_attn_mask)
|
||||||
|
|
||||||
|
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
|
||||||
|
and torch.cuda.get_device_capability()[0] >= 8)
|
||||||
|
|
||||||
|
if IS_COMPUTE_8_OR_ABOVE:
|
||||||
|
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
|
||||||
|
|
||||||
|
|
||||||
|
class LocalStridedBlockSparseAttn(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_heads,
|
||||||
|
max_seqlen,
|
||||||
|
local_blocks,
|
||||||
|
vert_stride,
|
||||||
|
block_size,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
homo_head=False,
|
||||||
|
active_head_range=None,
|
||||||
|
q_block_size=None,
|
||||||
|
use_spda=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if use_spda is None:
|
||||||
|
use_spda = is_hip() or is_cpu() or not \
|
||||||
|
IS_COMPUTE_8_OR_ABOVE
|
||||||
|
device = device or (torch.cuda.current_device()
|
||||||
|
if torch.cuda.is_available() else "cpu")
|
||||||
|
device = torch.device(device)
|
||||||
|
# NOTE: vllm CPU backend supports BF16 instead of FP16.
|
||||||
|
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
|
||||||
|
or device.type == "cpu" else torch.half)
|
||||||
|
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.max_seqlen = max_seqlen
|
||||||
|
self.local_blocks = local_blocks
|
||||||
|
self.vert_stride = vert_stride
|
||||||
|
self.use_spda = use_spda
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
self.block_size = block_size
|
||||||
|
self.q_block_size = q_block_size
|
||||||
|
self.homo_head = homo_head
|
||||||
|
self.active_head_range = active_head_range
|
||||||
|
self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
|
||||||
|
homo_head)
|
||||||
|
|
||||||
|
sparse_layout, sparse_pattern, self.dense_attn_mask = (
|
||||||
|
self.get_attn_pattern(dtype, device))
|
||||||
|
|
||||||
|
if q_block_size is not None and q_block_size != block_size:
|
||||||
|
if q_block_size > block_size:
|
||||||
|
assert q_block_size % block_size == 0
|
||||||
|
blocks_to_merge = q_block_size // block_size
|
||||||
|
shape = sparse_pattern.shape
|
||||||
|
sparse_pattern = sparse_pattern.view(shape[0], -1,
|
||||||
|
blocks_to_merge,
|
||||||
|
shape[-1])
|
||||||
|
sparse_pattern = sparse_pattern.sum(2)
|
||||||
|
sparse_layout = dense_to_crow_col(sparse_pattern)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Does not support smaller q_block_size. It will be slower."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sparse_layout = sparse_layout
|
||||||
|
|
||||||
|
def get_attn_pattern(self, dtype, device):
|
||||||
|
sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
|
||||||
|
self.n_heads,
|
||||||
|
self.max_seqlen,
|
||||||
|
self.max_seqlen,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
block_size=self.block_size,
|
||||||
|
local_blocks=self.local_blocks,
|
||||||
|
vert_stride=self.vert_stride,
|
||||||
|
homo_head=self.homo_head,
|
||||||
|
return_dense=self.use_spda,
|
||||||
|
dense_mask_type="bias",
|
||||||
|
)
|
||||||
|
if (not self.homo_head) and (self.active_head_range is not None):
|
||||||
|
assert isinstance(self.active_head_range, tuple)
|
||||||
|
assert (len(self.active_head_range) == 2)
|
||||||
|
h_start, h_end = self.active_head_range
|
||||||
|
sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
|
||||||
|
if self.use_spda:
|
||||||
|
dense_attn_mask = dense_attn_mask[h_start:h_end]
|
||||||
|
return sparse_layout, sparse_pattern, dense_attn_mask
|
||||||
|
|
||||||
|
def varlen_attn(self,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seqlens_q=None,
|
||||||
|
sm_scale=None):
|
||||||
|
"""
|
||||||
|
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Grouped attention is supported, where `q[:, i*r:(i*r + r)]`
corresponds to `k[:, i]` and `r` is the q/k head ratio.
cu_seqlens_k: shape=(batch_size + 1,),
indicating the segments of the samples,
e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the k of sample i
|
||||||
|
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||||
|
Default None: same as cu_seqlens_k for prefilling or
|
||||||
|
[0, 1, .., batch_size] for decoding.
|
||||||
|
The only case you need to specify is when q is a mix of
|
||||||
|
prefilling and decoding.
|
||||||
|
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||||
|
|
||||||
|
return: tensor of shape as q.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
IS_COMPUTE_8_OR_ABOVE
|
||||||
|
), "Requires compute capability of 8 or above (Ampere or newer) to use \
|
||||||
|
Triton kernel."
|
||||||
|
|
||||||
|
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||||
|
|
||||||
|
return blocksparse_flash_attn_varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seqlens_q,
|
||||||
|
sm_scale,
|
||||||
|
self.sparse_layout,
|
||||||
|
block_size=self.block_size,
|
||||||
|
q_block_size=self.q_block_size,
|
||||||
|
max_seqlen=self.max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
|
||||||
|
"""
|
||||||
|
:param x: (total_tokens, n_heads, head_size)
|
||||||
|
:return: (batch, n_heads, length, head_size)
|
||||||
|
"""
|
||||||
|
x_padded = x.new_empty(
|
||||||
|
len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
|
||||||
|
cu_seqlens = cu_seqlens.cpu()
|
||||||
|
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||||
|
x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
|
||||||
|
1).unsqueeze(1))
|
||||||
|
return x_padded.flatten(1, 2)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def transpose_and_unpad(x_padded, cu_seqlens):
|
||||||
|
"""
|
||||||
|
:param x_padded: (batch, n_heads, length, head_size)
|
||||||
|
:return: (total_tokens, n_heads, head_size)
|
||||||
|
"""
|
||||||
|
cu_seqlens = cu_seqlens.cpu()
|
||||||
|
total_n_tokens = cu_seqlens[-1]
|
||||||
|
x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
|
||||||
|
x_padded.size(3))
|
||||||
|
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
|
||||||
|
x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
|
||||||
|
return x
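Shape illustration (assumed numbers): for cu_seqlens = [0, 3, 5], transpose_and_pad maps x of shape (5, n_heads, head_size) to (2, n_heads * head_repeats, 3, head_size), i.e. (batch, heads, padded length, head_size), and transpose_and_unpad inverts this back to (5, n_heads, head_size).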
|
||||||
|
|
||||||
|
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||||
|
"""For CPU, V100 or other older GPUs.
|
||||||
|
NOTE: torch SDPA supports nested tensors,
but seems extremely slow, so we pad instead.
|
||||||
|
"""
|
||||||
|
assert (cu_seqlens_q is None or
|
||||||
|
(cu_seqlens_q
|
||||||
|
== cu_seqlens_k).all()), "Can only handle prompt with SDPA."
assert q.size(0) == k.size(0), "Can only handle prompt with SDPA."
|
||||||
|
|
||||||
|
assert q.size(1) % k.size(1) == 0
|
||||||
|
q_k_ratio = q.size(1) // k.size(1)
|
||||||
|
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
|
||||||
|
cu_seqlens = cu_seqlens_k.cpu()
|
||||||
|
maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
|
||||||
|
|
||||||
|
if (self.dense_attn_mask.dtype != q.dtype
|
||||||
|
or self.dense_attn_mask.device != q.device):
|
||||||
|
_, _, self.dense_attn_mask = self.get_attn_pattern(
|
||||||
|
q.dtype, q.device)
|
||||||
|
attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
|
||||||
|
|
||||||
|
q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
|
||||||
|
k2, v2 = [
|
||||||
|
self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
|
||||||
|
for x in [k, v]
|
||||||
|
]
|
||||||
|
spda_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
|
||||||
|
return self.transpose_and_unpad(spda_output, cu_seqlens)
|
||||||
|
|
||||||
|
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
|
||||||
|
"""Dispatch to `varlen_attn` (Ampere or newer) or
|
||||||
|
`self.spda`(cpu, Volta, Turing or older)based on
|
||||||
|
the type of device used and cuda compute capability.
|
||||||
|
|
||||||
|
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Grouped attention is supported, where `q[:, i*r:(i*r + r)]`
corresponds to `k[:, i]` and `r` is the q/k head ratio.
cu_seqlens_k: shape=(batch_size + 1,), indicating the segments of
the samples, e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the k of sample i
|
||||||
|
cu_seqlens_q: shape=(batch_size + 1, ).
|
||||||
|
Default None: same as cu_seqlens_k for prefilling or
|
||||||
|
[0, 1, .., batch_size] for decoding.
|
||||||
|
The only case you need to specify
|
||||||
|
is when q is a mix of prefilling
|
||||||
|
and decoding.
|
||||||
|
sm_scale: softmax scale, default to 1/sqrt(head_size).
|
||||||
|
|
||||||
|
return: tensor of shape as q.
|
||||||
|
"""
|
||||||
|
assert k.dim() == 3
|
||||||
|
if self.use_spda:
|
||||||
|
return self.spda(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
)
|
||||||
|
return self.varlen_attn(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_k,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
sm_scale=sm_scale)
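A minimal usage sketch for the module above (hypothetical shapes; assumes a CUDA device with compute capability 8.0+ so the Triton varlen path is taken, and uses the import path added in this PR):

# Sketch of calling the module for a two-sample prefill batch.
import torch
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)

n_heads, head_size, max_seqlen = 8, 64, 1024
attn = LocalStridedBlockSparseAttn(n_heads, max_seqlen,
                                   local_blocks=4, vert_stride=4,
                                   block_size=64)

# two prompts of length 100 and 60 -> 160 total tokens
cu_seqlens_k = torch.tensor([0, 100, 160], dtype=torch.int32, device="cuda")
q = torch.randn(160, n_heads, head_size, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v, cu_seqlens_k)   # same shape as q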
|
vllm/attention/ops/blocksparse_attention/utils.py (new file)
@ -0,0 +1,216 @@
|
|||||||
|
# Helper functions for 3D sparse pattern
# These functions are not optimized and are very inefficient.
# Avoid calling them too frequently, or use a caching mechanism.
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
from scipy import sparse
|
||||||
|
|
||||||
|
|
||||||
|
def dense_to_crow_col(x: torch.Tensor):
|
||||||
|
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
||||||
|
NOTE: col_indices padded -1
|
||||||
|
"""
|
||||||
|
device = x.device
|
||||||
|
pad = -1
|
||||||
|
dim = x.dim()
|
||||||
|
assert x.dim() in (2, 3)
|
||||||
|
if x.dim() == 2:
|
||||||
|
x = x[None]
|
||||||
|
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
||||||
|
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
||||||
|
cols = [torch.from_numpy(xi.indices) for xi in x]
|
||||||
|
max_cols = max(len(xi) for xi in cols)
|
||||||
|
cols = [
|
||||||
|
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
|
||||||
|
for xi in cols
|
||||||
|
]
|
||||||
|
cols = torch.vstack(cols)
|
||||||
|
if dim == 2:
|
||||||
|
crows = crows[0]
|
||||||
|
cols = cols[0]
|
||||||
|
return crows.to(device), cols.to(device)
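For reference, a tiny worked example of the CSR conversion above (illustrative only, CPU tensors):

import torch

mask = torch.tril(torch.ones(3, 3))       # 3x3 causal block mask
crows, cols = dense_to_crow_col(mask)
# crows -> tensor([0, 1, 3, 6]); cols -> tensor([0, 0, 1, 0, 1, 2])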
|
||||||
|
|
||||||
|
|
||||||
|
def crow_col_to_dense(crows: torch.Tensor,
|
||||||
|
cols: torch.Tensor,
|
||||||
|
dtype: torch.dtype = torch.float16):
|
||||||
|
dim = crows.dim()
|
||||||
|
if dim == 1:
|
||||||
|
crows = crows[None]
|
||||||
|
cols = cols[None]
|
||||||
|
device = crows.device
|
||||||
|
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
||||||
|
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
||||||
|
x = torch.zeros(shape, dtype=dtype)
|
||||||
|
for i in range(shape[0]):
|
||||||
|
for j in range(shape[1]):
|
||||||
|
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
|
||||||
|
if dim == 1:
|
||||||
|
x = x[0]
|
||||||
|
return x.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def dense_to_ccol_row(x: torch.Tensor):
|
||||||
|
"""Similar, but to CSC format"""
|
||||||
|
x = x.transpose(-2, -1)
|
||||||
|
return dense_to_crow_col(x)
|
||||||
|
|
||||||
|
|
||||||
|
def ccol_row_to_dense(ccol: torch.Tensor,
|
||||||
|
rows: torch.Tensor,
|
||||||
|
dtype: torch.dtype = torch.float16):
|
||||||
|
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sparse_attn_mask_homo_head(
|
||||||
|
q_len: int,
|
||||||
|
max_seqlen: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
block_size: int = 128,
|
||||||
|
local_blocks: int = 4,
|
||||||
|
vert_stride: int = 4,
|
||||||
|
return_dense: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:return: a tuple of 3:
|
||||||
|
- tuple of crow_indices, col_indices representation
|
||||||
|
of CSR format.
|
||||||
|
- block dense mask
|
||||||
|
- all-token dense mask (be aware that it can cause
OOM if it is too big) if `return_dense==True`,
otherwise, None
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||||
|
q_pos = torch.arange(num_blocks)[:, None]
|
||||||
|
k_pos = torch.arange(num_blocks)[None]
|
||||||
|
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
|
||||||
|
block_mask_dense = (((q_pos >= k_pos)
|
||||||
|
& ((q_pos - k_pos < local_blocks)
|
||||||
|
| mask_vert_strided)).to(device).to(dtype))
|
||||||
|
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||||
|
block_mask_dense_output = (dense_to_crow_col(
|
||||||
|
block_mask_dense[-num_blocks_q:].contiguous()))
|
||||||
|
if return_dense:
|
||||||
|
mask_dense = torch.kron(
|
||||||
|
block_mask_dense,
|
||||||
|
block_mask_dense.new_ones((block_size, block_size)),
|
||||||
|
)
|
||||||
|
causal_mask = torch.tril(torch.ones(
|
||||||
|
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||||
|
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
|
||||||
|
return (
|
||||||
|
block_mask_dense_output,
|
||||||
|
block_mask_dense,
|
||||||
|
mask_dense,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
block_mask_dense_output,
|
||||||
|
block_mask_dense,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def binary_mask_to_bias(mask_dense: torch.Tensor):
|
||||||
|
mask_dense = 1 - mask_dense
|
||||||
|
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
|
||||||
|
return mask_dense
|
||||||
|
|
||||||
|
|
||||||
|
def get_head_sliding_step(n_heads: int,
|
||||||
|
vert_stride: int,
|
||||||
|
homo_head: bool = False):
|
||||||
|
if homo_head:
|
||||||
|
return 0
|
||||||
|
return max(1, int(vert_stride / n_heads))
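Illustration (assumed numbers): with a total of n_heads=32 and vert_stride=8, the step is max(1, int(8 / 32)) = 1, and in get_sparse_attn_mask below head h keeps block columns where (col + h * 1 + 1) % 8 == 0, so each head's strided columns are shifted by one block relative to the previous head.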
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_sparse_attn_mask(
|
||||||
|
n_heads: int,
|
||||||
|
q_len: int,
|
||||||
|
max_seqlen: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
block_size: int = 64,
|
||||||
|
local_blocks: int = 4,
|
||||||
|
vert_stride: int = 4,
|
||||||
|
homo_head: bool = True,
|
||||||
|
return_dense: bool = False,
|
||||||
|
dense_mask_type: str = "binary",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
||||||
|
or "bias" (-inf for skip token, 0 or others)
|
||||||
|
:return: a tuple of 3:
|
||||||
|
- tuple of crow_indices, col_indices representation
|
||||||
|
of CSR format.
|
||||||
|
- block dense mask
|
||||||
|
- all-token dense mask (be aware that it can cause OOM if it
|
||||||
|
is too big) if `return_dense==True`, otherwise, None
|
||||||
|
"""
|
||||||
|
assert dense_mask_type in ("binary", "bias")
|
||||||
|
if homo_head:
|
||||||
|
with torch.no_grad():
|
||||||
|
(crow, col), block_mask_dense, mask_dense = (
|
||||||
|
_get_sparse_attn_mask_homo_head(
|
||||||
|
q_len,
|
||||||
|
max_seqlen,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
block_size,
|
||||||
|
local_blocks,
|
||||||
|
vert_stride,
|
||||||
|
return_dense,
|
||||||
|
))
|
||||||
|
crow = crow[None].expand(n_heads, crow.shape[0])
|
||||||
|
col = col[None].expand(n_heads, col.shape[0])
|
||||||
|
if return_dense:
|
||||||
|
mask_dense = mask_dense[None].expand(n_heads,
|
||||||
|
*mask_dense.shape)
|
||||||
|
if dense_mask_type == "bias":
|
||||||
|
mask_dense = binary_mask_to_bias(mask_dense)
|
||||||
|
return (crow, col), block_mask_dense, mask_dense
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
num_blocks = triton.cdiv(max_seqlen, block_size)
|
||||||
|
q_pos = torch.arange(num_blocks)[None, :, None]
|
||||||
|
k_pos = torch.arange(num_blocks)[None, None]
|
||||||
|
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
|
||||||
|
mask_vert_strided = [
|
||||||
|
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
|
||||||
|
vert_stride == 0 for h in range(n_heads)
|
||||||
|
]
|
||||||
|
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
||||||
|
block_mask_dense = (((q_pos >= k_pos)
|
||||||
|
& ((q_pos - k_pos < local_blocks)
|
||||||
|
| mask_vert_strided)).to(device).to(dtype))
|
||||||
|
num_blocks_q = triton.cdiv(q_len, block_size)
|
||||||
|
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
|
||||||
|
if return_dense:
|
||||||
|
mask_dense = torch.kron(
|
||||||
|
block_mask_dense,
|
||||||
|
block_mask_dense.new_ones((block_size, block_size)),
|
||||||
|
)
|
||||||
|
causal_mask = torch.tril(torch.ones(
|
||||||
|
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
||||||
|
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
|
||||||
|
if dense_mask_type == "bias":
|
||||||
|
mask_dense = binary_mask_to_bias(mask_dense)
|
||||||
|
|
||||||
|
return (
|
||||||
|
dense_to_crow_col(block_mask_dense_output),
|
||||||
|
block_mask_dense,
|
||||||
|
mask_dense,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
dense_to_crow_col(block_mask_dense_output),
|
||||||
|
block_mask_dense,
|
||||||
|
None,
|
||||||
|
)
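For a feel of the resulting pattern, a small sketch calling get_sparse_attn_mask (hypothetical sizes; runs on CPU, apart from the triton and scipy imports this module already requires):

import torch

(crow, col), block_mask, _ = get_sparse_attn_mask(
    n_heads=4, q_len=512, max_seqlen=512,
    dtype=torch.float32, device=torch.device("cpu"),
    block_size=64, local_blocks=2, vert_stride=4, homo_head=False)
print(block_mask.shape)             # torch.Size([4, 8, 8]): per-head block mask
print(block_mask[0].to(torch.int))  # lower-triangular local + strided columns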
|
@ -91,9 +91,21 @@ class PagedAttention:
|
|||||||
scale: float,
|
scale: float,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_scale: float,
|
kv_scale: float,
|
||||||
|
tp_rank: int = 0,
|
||||||
|
blocksparse_local_blocks: int = 0,
|
||||||
|
blocksparse_vert_stride: int = 0,
|
||||||
|
blocksparse_block_size: int = 64,
|
||||||
|
blocksparse_head_sliding_step: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
output = torch.empty_like(query)
|
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
|
||||||
|
# use blocksparse paged attention
|
||||||
|
block_size = value_cache.size(-1)
|
||||||
|
assert (blocksparse_block_size > 0 and
|
||||||
|
blocksparse_block_size % block_size == 0), \
|
||||||
|
(f"{blocksparse_block_size=} needs to be a multiple of"
|
||||||
|
f"{block_size=} used in block_tables.")
|
||||||
|
|
||||||
|
output = torch.empty_like(query)
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||||
@ -107,6 +119,7 @@ class PagedAttention:
|
|||||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||||
use_v1 = (max_seq_len <= 8192
|
use_v1 = (max_seq_len <= 8192
|
||||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||||
|
|
||||||
if use_v1:
|
if use_v1:
|
||||||
# Run PagedAttention V1.
|
# Run PagedAttention V1.
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
@ -123,6 +136,11 @@ class PagedAttention:
|
|||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
kv_scale,
|
||||||
|
tp_rank,
|
||||||
|
blocksparse_local_blocks,
|
||||||
|
blocksparse_vert_stride,
|
||||||
|
blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Run PagedAttention V2.
|
# Run PagedAttention V2.
|
||||||
@ -155,6 +173,11 @@ class PagedAttention:
|
|||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
kv_scale,
|
||||||
|
tp_rank,
|
||||||
|
blocksparse_local_blocks,
|
||||||
|
blocksparse_vert_stride,
|
||||||
|
blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step,
|
||||||
)
|
)
|
||||||
return output
|
return output
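Worked example of the V1/V2 kernel selection in the hunk above (assuming _PARTITION_SIZE == 512, which is not shown in this hunk):

max_seq_len, num_seqs, num_heads = 4096, 2, 32
max_num_partitions = (max_seq_len + 512 - 1) // 512      # = 8
use_v1 = (max_seq_len <= 8192
          and (max_num_partitions == 1 or num_seqs * num_heads > 512))
print(use_v1)  # False -> the partitioned V2 kernel is used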
|
||||||
|
|
||||||
|
@ -29,7 +29,14 @@ def get_attn_backend(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: Optional[str],
|
kv_cache_dtype: Optional[str],
|
||||||
block_size: int,
|
block_size: int,
|
||||||
|
is_blocksparse: bool = False,
|
||||||
) -> Type[AttentionBackend]:
|
) -> Type[AttentionBackend]:
|
||||||
|
|
||||||
|
if is_blocksparse:
|
||||||
|
logger.info("Using BlocksparseFlashAttention backend.")
|
||||||
|
from vllm.attention.backends.blocksparse_attn import (
|
||||||
|
BlocksparseFlashAttentionBackend)
|
||||||
|
return BlocksparseFlashAttentionBackend
|
||||||
"""Determine which attention backend to use and only import
|
"""Determine which attention backend to use and only import
|
||||||
the selected backend module.
|
the selected backend module.
|
||||||
"""
|
"""
|
||||||
|
@ -100,6 +100,7 @@ class OpenAIServing:
|
|||||||
token_logprob = step_top_logprobs[token_id].logprob
|
token_logprob = step_top_logprobs[token_id].logprob
|
||||||
token = step_top_logprobs[token_id].decoded_token
|
token = step_top_logprobs[token_id].decoded_token
|
||||||
logprobs.tokens.append(token)
|
logprobs.tokens.append(token)
|
||||||
|
token_logprob = max(token_logprob, -9999.0)
|
||||||
logprobs.token_logprobs.append(token_logprob)
|
logprobs.token_logprobs.append(token_logprob)
|
||||||
|
|
||||||
if num_output_top_logprobs:
|
if num_output_top_logprobs:
|
||||||
|
@ -56,6 +56,7 @@ _GENERATION_MODELS = {
|
|||||||
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
|
||||||
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
|
||||||
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
|
||||||
|
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_EMBEDDING_MODELS = {
|
_EMBEDDING_MODELS = {
|
||||||
|
vllm/model_executor/models/phi3_small.py (new file)
@ -0,0 +1,447 @@
|
|||||||
|
import math
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
|
||||||
|
def load_column_parallel_weight(param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor):
|
||||||
|
tp = get_tensor_model_parallel_world_size()
|
||||||
|
rk = get_tensor_model_parallel_rank()
|
||||||
|
assert param.size(0) * tp == loaded_weight.size(0)
|
||||||
|
s = rk * param.size(0)
|
||||||
|
e = (rk + 1) * param.size(0)
|
||||||
|
loaded_weight = loaded_weight[s:e]
|
||||||
|
assert param.shape == loaded_weight.shape
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|


class HeadMajorQKVParallelLinear(QKVParallelLinear):

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor):
        return load_column_parallel_weight(param, loaded_weight)


class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor):
        return load_column_parallel_weight(param, loaded_weight)
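A reading of the two subclasses above (not text from the PR): the Phi-3-Small checkpoint stores its fused QKV and up/gate projections in head-major order, so each tensor-parallel rank only needs a plain contiguous row slice of the loaded tensor, bypassing the re-interleaving that the stock QKVParallelLinear/MergedColumnParallelLinear loaders perform. The slice arithmetic, with made-up sizes:

import torch

def shard_rows(full_weight: torch.Tensor, rank: int, tp: int) -> torch.Tensor:
    # Same arithmetic as load_column_parallel_weight: equal contiguous
    # row shards along dim 0, one per tensor-parallel rank.
    assert full_weight.size(0) % tp == 0
    shard = full_weight.size(0) // tp
    return full_weight[rank * shard:(rank + 1) * shard]

full = torch.arange(8 * 4).reshape(8, 4)     # toy fused weight, 8 output rows
print(shard_rows(full, rank=1, tp=2).shape)  # torch.Size([4, 4])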


@torch.jit.script
def quick_gelu(x):
    return x * torch.sigmoid(1.702 * x)


@torch.jit.script
def gegelu(input, limit: Optional[float] = None):
    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
    if limit is not None:
        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
                             a_gelu.clamp(min=None, max=limit))
        a_linear = torch.where(
            torch.isinf(a_linear),
            a_linear,
            a_linear.clamp(min=-limit, max=limit),
        )
    out_gelu = quick_gelu(a_gelu)
    return out_gelu * (a_linear + 1)
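gegelu above treats the fused up-projection as interleaved (gate, linear) column pairs: even columns pass through quick_gelu, odd columns stay linear (optionally clamped to ±limit), and the two halves are multiplied, halving the feature dimension. A quick shape check with toy sizes, assuming the gegelu/quick_gelu defined above are in scope:

import torch

x = torch.randn(2, 5, 8)        # (batch, seq, 2 * intermediate_size = 8)
out = gegelu(x, limit=10.0)
print(out.shape)                # torch.Size([2, 5, 4])

# Equivalent unfused computation for finite inputs:
gate, lin = x[..., ::2], x[..., 1::2]
g = gate.clamp(max=10.0)
ref = (g * torch.sigmoid(1.702 * g)) * (lin.clamp(-10.0, 10.0) + 1)
print(torch.allclose(out, ref))  # True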


class Phi3SmallMLP(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        assert (self.config.hidden_act == "gegelu"
                ), "Only `gegelu` is supported for the 4.7 series of models .."
        self.hidden_size = config.hidden_size
        self.gegelu_limit = config.gegelu_limit
        self.intermediate_size = config.intermediate_size

        self.up_proj = HeadMajorColumnParallelLinear(
            self.hidden_size,
            2 * [self.intermediate_size],
            bias=True,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        gate_up, _ = self.up_proj(x)
        x = gegelu(gate_up)
        x, _ = self.down_proj(x)
        return x


class Phi3SmallSelfAttention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.sparse_block_size = config.blocksparse_block_size
        self.homo_heads = config.blocksparse_homo_head_pattern
        self.local_blocks = config.blocksparse_num_local_blocks
        self.vert_stride = config.blocksparse_vert_stride

        assert (config.blocksparse_block_size ==
                config.blocksparse_triton_kernel_block_size)

        self.hidden_size = config.hidden_size
        # Number of Query Heads
        self.num_heads = config.num_attention_heads

        self.head_dim = self.hidden_size // self.num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        # Number of total Key Value Heads before tensor parallel
        self.num_key_value_heads = config.num_key_value_heads
        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
        if self.tp_size > 1:
            assert self.num_key_value_heads % self.tp_size == 0
        self.num_kv_heads_per_partion = max(
            1, self.num_key_value_heads // self.tp_size)
        self.num_heads_per_partition = self.num_heads // self.tp_size

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_embedding_base = config.rope_embedding_base
        self.rope_position_scale = config.rope_position_scale
        self.is_causal = True

        norm_factor = None
        if config.mup_use_scaling:
            norm_factor = self.head_dim / config.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(self.head_dim)
        self.scale = 1 / norm_factor

        self.query_key_value = HeadMajorQKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.dense = RowParallelLinear(self.hidden_size,
                                       self.hidden_size,
                                       bias=True,
                                       quant_config=quant_config)

        if getattr(self.config, "rope_scaling", None) is not None:
            rope_scaling = self.config.rope_scaling
            for key in rope_scaling:
                if isinstance(rope_scaling[key], list):
                    rope_scaling[key] = tuple(rope_scaling[key])

            if "factor" not in rope_scaling:
                rope_scaling["factor"] = self.rope_position_scale
        else:
            rope_scaling = {
                "type": "linear",
                "factor": self.rope_position_scale,
            }

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_embedding_base,
            rope_scaling=rope_scaling,
        )

        # blocksparse params
        self.blocksparse_block_size = config.blocksparse_block_size
        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
        self.blocksparse_vert_stride = config.blocksparse_vert_stride

        use_dense_attn = (getattr(self.config,
                                  "dense_attention_every_n_layers", None)
                          and (self.layer_idx + 1) %
                          self.config.dense_attention_every_n_layers == 0)

        bs_params = None
        if not use_dense_attn:
            bs_params = {
                'max_seqlen': self.max_position_embeddings,
                'num_heads': self.num_heads_per_partition,
                "num_kv_heads": self.num_kv_heads_per_partion,
                "block_size": self.sparse_block_size,
                "local_blocks": self.local_blocks,
                "vert_stride": self.vert_stride,
                "homo_head": self.homo_heads
            }

        self.attn = Attention(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads_per_partion,
            cache_config=cache_config,
            quant_config=quant_config,
            blocksparse_params=bs_params,
        )
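The bs_params dict above is consumed by the blocksparse attention backend added in this PR. As a rough mental model (a sketch of the local-plus-vertical-stride pattern, not the Triton kernel itself): each query block attends to the previous local_blocks blocks and to every vert_stride-th earlier block, with a per-head phase shift when homo_head is False. A dense boolean mask for one head might look like this (illustrative indexing, toy sizes):

import torch

def blocksparse_mask(n_blocks: int, local_blocks: int, vert_stride: int,
                     head_offset: int = 0) -> torch.Tensor:
    # mask[q, k] == True  ->  query block q may attend to key block k.
    q = torch.arange(n_blocks)[:, None]
    k = torch.arange(n_blocks)[None, :]
    causal = k <= q
    local = (q - k) < local_blocks
    # Vertical stripes; head_offset models the per-head shift used when
    # homo_head is False (illustrative, not the kernel's exact indexing).
    vertical = (k + head_offset) % vert_stride == 0
    return causal & (local | vertical)

print(blocksparse_mask(8, local_blocks=2, vert_stride=4).int())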

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        qkv, _ = self.query_key_value(hidden_states)

        qkv = qkv.view(qkv.shape[:-1] +
                       (-1, (self.num_q_per_kv + 2), self.head_dim))
        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)

        # NOTE: this is required by RotaryEmbed, which indeed does not have to
        # TODO: allow 3D QK for rotary forward
        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
        output, _ = self.dense(attn_output)

        return output
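The view/split in forward above relies on the head-major packing of query_key_value: for each KV head the projection emits num_q_per_kv query heads, then one key head, then one value head. A shape walk-through with assumed toy sizes (not the real config):

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4   # toy sizes
num_q_per_kv = num_heads // num_kv_heads      # 4
tokens = 3
qkv = torch.randn(tokens, num_kv_heads * (num_q_per_kv + 2) * head_dim)

qkv = qkv.view(qkv.shape[:-1] + (-1, num_q_per_kv + 2, head_dim))
q, k, v = qkv.split([num_q_per_kv, 1, 1], dim=-2)
print(q.shape, k.shape, v.shape)
# torch.Size([3, 2, 4, 4]) torch.Size([3, 2, 1, 4]) torch.Size([3, 2, 1, 4])

# Flattened back to 2D, as the rotary embedding and Attention expect:
q = q.reshape(-1, head_dim * num_heads)       # (3, 32)
k = k.reshape(-1, head_dim * num_kv_heads)    # (3, 8)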


class Phi3SmallDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Phi3SmallSelfAttention(config,
                                                layer_idx,
                                                cache_config=cache_config,
                                                quant_config=quant_config)
        self.mlp = Phi3SmallMLP(config, quant_config)

        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Phi3SmallModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.mup_embedding_multiplier = config.mup_embedding_multiplier
        self.layers = nn.ModuleList([
            Phi3SmallDecoderLayer(config, layer_idx, cache_config,
                                  quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ])

        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_epsilon)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata = None,
    ):
        hidden_states = self.embed_tokens(input_ids)
        if (self.mup_embedding_multiplier is not None
                and self.mup_embedding_multiplier > 0.0):
            hidden_states = hidden_states * self.mup_embedding_multiplier
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class Phi3SmallForCausalLM(nn.Module):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Phi3SmallModel(config, cache_config, quant_config)
        self.vocab_size = config.vocab_size
        self.mup_width_multiplier = config.mup_width_multiplier
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

        # tokens in tiktoken but not used
        if hasattr(config, 'dummy_token_indices'):
            device = self.lm_head.weight.device
            self.register_buffer('dummy_token_indices',
                                 torch.LongTensor(
                                     config.dummy_token_indices).to(device),
                                 persistent=False)
        else:
            self.dummy_token_indices = None

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        if self.dummy_token_indices is not None and logits is not None:
            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        output_hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        output_hidden_states = output_hidden_states
        return output_hidden_states

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:

        next_tokens = self.sampler(logits / self.mup_width_multiplier,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
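Two model-specific details in the class above, as read from the code: compute_logits pushes the unused tiktoken vocabulary slots to -inf so they can never be sampled, and sample divides the logits by mup_width_multiplier to undo the muP width scaling baked into the checkpoint. A compact illustration with toy values:

import torch

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])
dummy_token_indices = torch.tensor([1, 2])   # unused tiktoken slots (toy values)
mup_width_multiplier = 8.0                   # toy value; read from the config

# Mask dummy vocabulary entries exactly as compute_logits does.
logits.index_fill_(-1, dummy_token_indices, -torch.inf)

# Rescale before sampling, as in sample(); softmax then ignores masked slots.
probs = torch.softmax(logits / mup_width_multiplier, dim=-1)
print(probs)  # zero probability at indices 1 and 2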