[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799)

Co-authored-by: beagleski <yunanzhang@microsoft.com>
Co-authored-by: bapatra <bapatra@microsoft.com>
Co-authored-by: Barun Patra <codedecde@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Eric Xihui Lin 2024-05-25 01:00:52 -04:00 committed by GitHub
parent e64fde4b01
commit 8e192ff967
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2445 additions and 87 deletions

View File

@ -85,6 +85,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@ -104,7 +105,9 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) {
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
@ -202,11 +205,55 @@ __device__ void paged_attention_kernel(
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
@ -335,6 +382,15 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@ -441,8 +497,8 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE>
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) {
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE>(
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale);
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) {
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, PARTITION_SIZE>(
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale);
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
@ -607,25 +672,32 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE> \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale);
kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -691,23 +763,36 @@ void paged_attention_v1_launcher(
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale);
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@ -727,18 +812,26 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale){
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale); \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
@ -746,14 +839,17 @@ void paged_attention_v1(
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -824,24 +920,36 @@ void paged_attention_v2_launcher(
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale);
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@ -865,7 +973,10 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
@ -873,4 +984,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP

View File

@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
}
} // namespace
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
}
} // namespace
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

View File

@ -2,23 +2,24 @@
#include <torch/extension.h>
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale);
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon);

View File

@ -123,6 +123,10 @@ Alongside each architecture, we include some popular models that use it.
- Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
-
* - :code:`Phi3SmallForCausalLM`
- Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.

View File

@ -0,0 +1,442 @@
import random
from typing import List, Optional, Tuple
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.utils import get_max_shared_memory_bytes, is_hip
from .allclose_default import get_default_atol, get_default_rtol
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# MAX_SEQ_LEN = 2771
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [3] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
BLOCKSPARSE_LOCAL_BLOCKS = [16]
BLOCKSPARSE_VERT_STRIDES = [8]
BLOCKSPARSE_BLOCK_SIZES = [64]
BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1]
BLOCKSPARSE_HOMO_HEADS = [True, False]
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def ref_single_query_cached_kv_attention(
output: torch.Tensor,
query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 1,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
num_query_heads = query.shape[1]
num_kv_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
seq_len = int(seq_lens[i])
keys = []
values = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size)
keys.append(k)
v = value_cache[block_number, :, :, block_offset]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
if blocksparse_vert_stride >= 1:
bsize = blocksparse_block_size
hsliding = blocksparse_head_sliding_step
vert = blocksparse_vert_stride
locals = blocksparse_local_blocks
qb = (seq_len - 1) // bsize
attn_mask = q.new_zeros(
(num_query_heads, 1, seq_len)).float() - torch.inf
for h in range(num_query_heads):
if hsliding >= 0: # slide with q heads
bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
else: # slide with kv heads
bs_offset = (tp_rank * num_kv_heads +
h // num_queries_per_kv) * (-hsliding) + 1
for kb in range(qb + 1):
kj = kb * bsize
if (qb - kb) < locals or \
(kb + bs_offset) % vert == 0:
attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
if alibi_bias is not None:
attn_mask += alibi_bias
else:
attn_mask = alibi_bias
out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
out = out.view(num_query_heads, head_size)
output[i].copy_(out, non_blocking=True)
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_head_sliding_step",
BLOCKSPARSE_HEADS_SLIDINGS)
def test_paged_attention(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str,
blocksparse_local_blocks: int,
blocksparse_vert_stride: int,
blocksparse_block_size: int,
blocksparse_head_sliding_step: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
kv_scale = 1.0
tp_rank = 0
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
)
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
)
else:
raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation.
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
seq_lens,
scale,
alibi_slopes,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention(
cu_seq_lens: List[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
for i in range(num_seqs):
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
seq_len = end_idx - start_idx
# Create attention mask.
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
blocksparse_local_blocks: int,
blocksparse_vert_stride: int,
blocksparse_block_size: int,
blocksparse_homo_heads: bool,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs)
cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
bs_attn_op = LocalStridedBlockSparseAttn(
num_query_heads,
max_len,
local_blocks=blocksparse_local_blocks,
vert_stride=blocksparse_vert_stride,
block_size=blocksparse_block_size,
device=device,
dtype=dtype,
homo_head=blocksparse_homo_heads)
output = bs_attn_op(query,
key,
value,
cu_seq_lens.to(device),
sm_scale=scale)
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
query,
key,
value,
scale,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2)

View File

@ -45,11 +45,17 @@ def paged_attention_v1(
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)
vllm_ops.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2(
@ -69,12 +75,18 @@ def paged_attention_v2(
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)
vllm_ops.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops

View File

@ -111,6 +111,7 @@ class AttentionImpl(ABC, Generic[T]):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
raise NotImplementedError

View File

@ -0,0 +1,410 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
@dataclass
class BlocksparseParams:
max_seqlen: int
# Num q heads per tensor-parallel rank/partition
num_heads: int # per TP partition
# Num kv heads per tensor-parallel rank/partition
num_kv_heads: int
# block size used for blocksparse attention.
# This is the block_size used in `local_blocks`, `vert_stride`.
block_size: int
# Number of blocks for local attention, i.e., number of
# local attended tokens / `sparse_block_size`
local_blocks: int
# Attend to one block per every `vert_stride` blocks.
# Controlling the sparsity
vert_stride: int
"""
If to use the same vertical stride offset for all heads,
i.e., attend to the same block of tokens on all heads.
By default, it is False, i.e., attention on the non-local
blocks depends on the `head_idx`, that is on
blocks satisfying
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
`block_idx = position_id // sparse_block_size`.
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
for more detail.
"""
homo_head: bool = False
# If within a group, the kv offsets that each q attends is the same or no.
homo_head_group: bool = False
# Decided by homo_head and homo_head group
head_sliding_step: int = field(init=False)
# range of q heads to for a TP rank
active_head_range: Tuple = field(init=False)
def __post_init__(self):
assert self.block_size > 0
assert self.local_blocks >= 0
assert self.vert_stride >= 1
assert self.num_heads % self.num_kv_heads == 0
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
total_heads = tp_size * self.num_heads
total_kv_heads = tp_size * self.num_kv_heads
if self.homo_head:
self.head_sliding_step = 0
elif self.homo_head_group:
head_sliding_step = get_head_sliding_step(total_kv_heads,
self.vert_stride)
# negative indicates sliding along kv heads, i.e., homo q group
self.head_sliding_step = -head_sliding_step
else:
self.head_sliding_step = get_head_sliding_step(
total_heads, self.vert_stride)
self.active_head_range = (
tp_rank * self.num_heads,
(tp_rank + 1) * self.num_heads,
)
class BlocksparseFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
return BlocksparseFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
return BlocksparseFlashAttentionMetadata(*args, **kwargs)
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
"""A copy of Metadata for FlashAttentionBackend,
to avoid having to install flash_attn.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
@property
def prefill_metadata(
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
assert sliding_window is None, ValueError(
"sliding_window is invalid for blocksparse attention.")
if "num_heads" not in blocksparse_params:
blocksparse_params["num_heads"] = num_heads
if "num_kv_heads" not in blocksparse_params:
blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
self.blocksparse_params = BlocksparseParams(**blocksparse_params)
self.kv_cache_dtype = kv_cache_dtype
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.alibi_slopes = alibi_slopes
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.local_blocks = self.blocksparse_params.local_blocks
self.vert_stride = self.blocksparse_params.vert_stride
self.sparse_block_size = self.blocksparse_params.block_size
self.head_sliding_step = self.blocksparse_params.head_sliding_step
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
total_num_heads = num_heads * self.tp_size
self.bs_attn = LocalStridedBlockSparseAttn(
total_num_heads,
self.blocksparse_params.max_seqlen,
self.blocksparse_params.local_blocks,
self.blocksparse_params.vert_stride,
self.blocksparse_params.block_size,
homo_head=self.blocksparse_params.homo_head,
active_head_range=self.blocksparse_params.active_head_range,
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache is not None:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
kv_scale,
)
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
assert kv_cache is None \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention."
output = self.bs_attn(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
sm_scale=self.scale,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
self.blocksparse_params.max_seqlen,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,
blocksparse_block_size=self.sparse_block_size,
blocksparse_head_sliding_step=self.head_sliding_step,
)
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)

View File

@ -1,6 +1,6 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@ -219,7 +219,10 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@ -1,6 +1,6 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
@ -201,7 +201,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"ROCFlashAttention does not support blocksparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@ -1,7 +1,7 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
@ -100,7 +100,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"Torch SPDA does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@ -1,6 +1,6 @@
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
@ -212,7 +212,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
assert blocksparse_params is None, ValueError(
"XFormer does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@ -1,5 +1,5 @@
"""Attention layer."""
from typing import List, Optional
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
@ -33,6 +33,7 @@ class Attention(nn.Module):
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
if cache_config is not None:
@ -69,10 +70,12 @@ class Attention(nn.Module):
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
block_size, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype)
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params)
def forward(
self,
@ -90,4 +93,5 @@ class Attention(nn.Module):
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s

View File

@ -0,0 +1,423 @@
import torch
import triton
import triton.language as tl
def blocksparse_flash_attn_varlen_fwd(
q,
k,
v, # (#tokens, n_heads, head_size)
cu_seqlens_k,
cu_seqlens_q,
sm_scale,
sparse_layout,
*,
block_size=64,
q_block_size=None,
max_seqlen=None):
# split q to blocks
assert isinstance(sparse_layout, (list, tuple))
_, n_heads, head_size = q.shape
batch_size = cu_seqlens_k.size(0) - 1
q_block_size = q_block_size or block_size
assert q.dim() == k.dim() == v.dim() == 3
assert q.size(1) % k.size(1) == 0
assert q.size(2) == k.size(2)
# TODO(linxihui): allow k, v to have different head_size
assert k.shape == v.shape
assert cu_seqlens_k.dim() == 1
q_k_ratio = q.size(1) // k.size(1)
if cu_seqlens_q is None:
if q.size(0) == batch_size: # decoding only
cu_seqlens_q = torch.arange(
0,
batch_size + 1,
dtype=cu_seqlens_k.dtype,
device=cu_seqlens_k.device,
)
elif q.size(0) == k.size(0):
cu_seqlens_q = cu_seqlens_k
else:
raise ValueError("cu_seqlens_q must be specified\
if it mix of prefilling and decoding.")
else:
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
# switch to use cpu to avoid too many kernel launches when iterated over
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
"length of q should either be 1 (decoding) or same as k (prefilling).")
if max_seqlen:
assert k_lens.max() <= max_seqlen
n_blocks = (q_lens + q_block_size - 1) // q_block_size
q_batch_ids = torch.tensor(
[i for i, n in enumerate(n_blocks) for _ in range(n)],
dtype=cu_seqlens_q.dtype,
device=cu_seqlens_q.device,
)
q_start_sids = torch.tensor(
[i * q_block_size for n in n_blocks for i in range(n)],
dtype=cu_seqlens_q.dtype,
device=cu_seqlens_q.device,
)
out = q.new_empty(q.shape)
cu_seqlens_q = cu_seqlens_q.contiguous()
cu_seqlens_k = cu_seqlens_k.contiguous()
layout_crow_indices, layout_col_indices = sparse_layout
block_d = triton.next_power_of_2(head_size)
decoding_only = (q_lens == 1).all().item()
grid = (len(q_start_sids), n_heads, 1)
_fwd_kernel_batch_inference[grid](
q,
k,
v,
out,
sm_scale,
cu_seqlens_q[:-1],
cu_seqlens_q[1:],
cu_seqlens_k[:-1],
cu_seqlens_k[1:],
q_batch_ids,
q_start_sids,
0,
*q.stride(),
0,
*k.stride(),
0,
*v.stride(),
0,
*out.stride(),
layout_crow_indices,
layout_col_indices,
*layout_crow_indices.stride(),
*layout_col_indices.stride(),
q_k_ratio,
HAS_BATCH_DIM=False,
D_HEAD=head_size,
BLOCK_M=q_block_size,
BLOCK_N=block_size,
BLOCK_D=block_d,
BLOCK_M_LOADING=(16 if decoding_only else
q_block_size), # smaller for decoding
EVEN_D=block_d == head_size,
num_warps=1 if decoding_only else 4,
num_stages=3)
return out
@triton.jit
def _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_col_idx,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
LAST_K_BLOCK: tl.constexpr,
BLOCK_M_LOADING: tl.constexpr,
BLOCK_N: tl.constexpr,
D_HEAD: tl.constexpr,
EVEN_D: tl.constexpr,
M_LT_N: tl.constexpr,
):
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
k_block_col_idx * layout_col_stride_m).to(tl.int32)
start_n = k_block_id * BLOCK_N
if LAST_K_BLOCK:
if EVEN_D:
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=offs_n[None, :] + start_n < k_seqlen,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=(offs_n[None, :] + start_n < k_seqlen) &
(offs_d[:, None] < D_HEAD),
)
else:
if EVEN_D:
k = tl.load(k_ptrs + start_n * stride_kt)
else:
k = tl.load(k_ptrs + start_n * stride_kt,
mask=offs_d[:, None] < D_HEAD)
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK | M_LT_N:
qk += tl.where(
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
0,
float("-inf"),
)
# flash-attn2
m_ij = tl.maximum(m_i, tl.max(qk, 1))
p = tl.math.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
# update m_i
m_i = m_ij
l_i = l_i * alpha + l_ij
p = p.to(Q.dtype.element_ty)
# update acc
if LAST_K_BLOCK:
if EVEN_D:
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=offs_n[:, None] + start_n < k_seqlen,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=(offs_n[:, None] + start_n < k_seqlen) &
(offs_d[None, :] < D_HEAD),
)
else:
if EVEN_D:
v = tl.load(v_ptrs + start_n * stride_vt)
else:
v = tl.load(v_ptrs + start_n * stride_vt,
mask=offs_d[None, :] < D_HEAD)
acc += tl.dot(p, v)
return acc, l_i, m_i
@triton.heuristics({
"M_LT_N":
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
Q,
K,
V,
Out,
sm_scale,
q_batch_starts,
q_batch_ends,
k_batch_starts,
k_batch_ends,
q_batch_ids,
q_start_sids,
stride_qb,
stride_qt,
stride_qh,
stride_qd,
stride_kb,
stride_kt,
stride_kh,
stride_kd,
stride_vb,
stride_vt,
stride_vh,
stride_vd,
stride_ob,
stride_ot,
stride_oh,
stride_od,
layout_crow_ptr,
layout_col_ptr,
layout_crow_stride_h,
layout_crow_stride_m,
layout_col_stride_h,
layout_col_stride_m,
q_k_ratio,
HAS_BATCH_DIM: tl.constexpr,
D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_M_LOADING: tl.constexpr,
EVEN_D: tl.constexpr,
M_LT_N: tl.constexpr,
):
"""
NOTATION:
pid: position id
sid: storage id
sbid: storage block id
pbid: position block id
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
TODO(linxihui):
Optimize grouped-attn
"""
off_zm = tl.program_id(0)
off_h = tl.program_id(1)
off_h_for_kv = off_h // q_k_ratio
if HAS_BATCH_DIM:
off_z = tl.program_id(2)
Q += off_z * stride_qb
K += off_z * stride_kb
V += off_z * stride_vb
Out += off_z * stride_ob
start_m = off_zm
q_start_sid = start_m * BLOCK_M # always 0 for decoding
else:
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
q_start_sid = tl.load(q_start_sids + off_zm)
start_m = q_start_sid // BLOCK_M # q_sbid
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
past_len = k_seqlen - q_seqlen
Q += q_cu_start * stride_qt + off_h * stride_qh
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
Out += q_cu_start * stride_ot + off_h * stride_oh
q_pbid = (past_len + q_start_sid) // BLOCK_M
if EVEN_D:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=offs_m[:, None] < q_seqlen,
)
else:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
other=0,
)
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
q_pbid * layout_crow_stride_m)
# TODO(linxihui): load at once, with any Triton version
# that supports `tl.split`, e.g., Triton 3.0
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
sm_scale *= (
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
)
for k_block_col_idx in range(k_block_start, k_block_end - 1):
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_col_idx,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
False,
BLOCK_M_LOADING,
BLOCK_N,
D_HEAD,
EVEN_D,
M_LT_N,
)
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_end - 1,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
True,
BLOCK_M_LOADING,
BLOCK_N,
D_HEAD,
EVEN_D,
M_LT_N,
)
# flash-attn 2
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
# write output
if EVEN_D:
tl.store(
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
acc,
mask=offs_m[:, None] < q_seqlen,
)
else:
tl.store(
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
acc,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
)

View File

@ -0,0 +1,238 @@
import math
import torch
from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
class LocalStridedBlockSparseAttn(torch.nn.Module):
def __init__(
self,
n_heads,
max_seqlen,
local_blocks,
vert_stride,
block_size,
device=None,
dtype=None,
homo_head=False,
active_head_range=None,
q_block_size=None,
use_spda=None,
):
super().__init__()
if use_spda is None:
use_spda = is_hip() or is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if torch.cuda.is_available() else "cpu")
device = torch.device(device)
# NOTE: vllm CPU backend support BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
or device.type == "cpu" else torch.half)
self.n_heads = n_heads
self.max_seqlen = max_seqlen
self.local_blocks = local_blocks
self.vert_stride = vert_stride
self.use_spda = use_spda
self.dtype = dtype
self.device = device
self.block_size = block_size
self.q_block_size = q_block_size
self.homo_head = homo_head
self.active_head_range = active_head_range
self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
homo_head)
sparse_layout, sparse_pattern, self.dense_attn_mask = (
self.get_attn_pattern(dtype, device))
if q_block_size is not None and q_block_size != block_size:
if q_block_size > block_size:
assert q_block_size % block_size == 0
blocks_to_merge = q_block_size // block_size
shape = sparse_pattern.shape
sparse_pattern = sparse_pattern.view(shape[0], -1,
blocks_to_merge,
shape[-1])
sparse_pattern = sparse_pattern.sum(2)
sparse_layout = dense_to_crow_col(sparse_pattern)
else:
raise ValueError(
"Does not support smaller q_block_size. It will be slower."
)
self.sparse_layout = sparse_layout
def get_attn_pattern(self, dtype, device):
sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
self.n_heads,
self.max_seqlen,
self.max_seqlen,
dtype,
device,
block_size=self.block_size,
local_blocks=self.local_blocks,
vert_stride=self.vert_stride,
homo_head=self.homo_head,
return_dense=self.use_spda,
dense_mask_type="bias",
)
if (not self.homo_head) and (self.active_head_range is not None):
assert isinstance(self.active_head_range, tuple)
assert (len(self.active_head_range) == 2)
h_start, h_end = self.active_head_range
sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
if self.use_spda:
dense_attn_mask = dense_attn_mask[h_start:h_end]
return sparse_layout, sparse_pattern, dense_attn_mask
def varlen_attn(self,
q,
k,
v,
cu_seqlens_k,
cu_seqlens_q=None,
sm_scale=None):
"""
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,),
indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify is when q is a mix of
prefilling and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
return: tensor of shape as q.
"""
assert (
IS_COMPUTE_8_OR_ABOVE
), "Requires compute capability of 8 or above (Ampere or newer) to use \
Triton kernel."
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
return blocksparse_flash_attn_varlen_fwd(
q,
k,
v,
cu_seqlens_k,
cu_seqlens_q,
sm_scale,
self.sparse_layout,
block_size=self.block_size,
q_block_size=self.q_block_size,
max_seqlen=self.max_seqlen,
)
@staticmethod
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
"""
:param x: (total_tokens, n_heads, head_size)
:return: (batch, n_heads, length, head_size)
"""
x_padded = x.new_empty(
len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
cu_seqlens = cu_seqlens.cpu()
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
1).unsqueeze(1))
return x_padded.flatten(1, 2)
@staticmethod
def transpose_and_unpad(x_padded, cu_seqlens):
"""
:param x_padded: (batch, n_heads, length, head_size)
:return: (total_tokens, n_heads, head_size)
"""
cu_seqlens = cu_seqlens.cpu()
total_n_tokens = cu_seqlens[-1]
x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
x_padded.size(3))
for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
return x
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""For CPU, V100 or other older GPUs.
NOTE: torch SPDA supports nested tensor,
but seems extremely slow. Choose to pad instead.
"""
assert (cu_seqlens_q is None or
(cu_seqlens_q
== cu_seqlens_k).all()), "Can only handle prompt with SPDA."
assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
assert q.size(1) % k.size(1) == 0
q_k_ratio = q.size(1) // k.size(1)
sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
cu_seqlens = cu_seqlens_k.cpu()
maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
if (self.dense_attn_mask.dtype != q.dtype
or self.dense_attn_mask.device != q.device):
_, _, self.dense_attn_mask = self.get_attn_pattern(
q.dtype, q.device)
attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
k2, v2 = [
self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
for x in [k, v]
]
spda_output = torch.nn.functional.scaled_dot_product_attention(
q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
return self.transpose_and_unpad(spda_output, cu_seqlens)
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda`(cpu, Volta, Turing or older)based on
the type of device used and cuda compute capability.
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify
is when q is a mix of prefilling
and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
return: tensor of shape as q.
"""
assert k.dim() == 3
if self.use_spda:
return self.spda(
q,
k,
v,
cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q,
sm_scale=sm_scale,
)
return self.varlen_attn(q,
k,
v,
cu_seqlens_k,
cu_seqlens_q=cu_seqlens_q,
sm_scale=sm_scale)

View File

@ -0,0 +1,216 @@
# Helper functions for 3D sparse pattern
# These function are not optimized and very inefficient.
# Avoid calling them too frequent or use a cache mechanism.
from functools import lru_cache
import torch
import triton
from scipy import sparse
def dense_to_crow_col(x: torch.Tensor):
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
NOTE: col_indices padded -1
"""
device = x.device
pad = -1
dim = x.dim()
assert x.dim() in (2, 3)
if x.dim() == 2:
x = x[None]
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols)
cols = [
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
for xi in cols
]
cols = torch.vstack(cols)
if dim == 2:
crows = crows[0]
cols = cols[0]
return crows.to(device), cols.to(device)
def crow_col_to_dense(crows: torch.Tensor,
cols: torch.Tensor,
dtype: torch.dtype = torch.float16):
dim = crows.dim()
if dim == 1:
crows = crows[None]
cols = cols[None]
device = crows.device
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
x = torch.zeros(shape, dtype=dtype)
for i in range(shape[0]):
for j in range(shape[1]):
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
if dim == 1:
x = x[0]
return x.to(device)
def dense_to_ccol_row(x: torch.Tensor):
"""Similar, but to CSC format"""
x = x.transpose(-2, -1)
return dense_to_crow_col(x)
def ccol_row_to_dense(ccol: torch.Tensor,
rows: torch.Tensor,
dtype: torch.dtype = torch.float16):
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
def _get_sparse_attn_mask_homo_head(
q_len: int,
max_seqlen: int,
dtype: torch.dtype,
device: torch.device,
block_size: int = 128,
local_blocks: int = 4,
vert_stride: int = 4,
return_dense: bool = False,
):
"""
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
otherwise, None
"""
with torch.no_grad():
num_blocks = triton.cdiv(max_seqlen, block_size)
q_pos = torch.arange(num_blocks)[:, None]
k_pos = torch.arange(num_blocks)[None]
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
block_mask_dense = (((q_pos >= k_pos)
& ((q_pos - k_pos < local_blocks)
| mask_vert_strided)).to(device).to(dtype))
num_blocks_q = triton.cdiv(q_len, block_size)
block_mask_dense_output = (dense_to_crow_col(
block_mask_dense[-num_blocks_q:].contiguous()))
if return_dense:
mask_dense = torch.kron(
block_mask_dense,
block_mask_dense.new_ones((block_size, block_size)),
)
causal_mask = torch.tril(torch.ones(
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
return (
block_mask_dense_output,
block_mask_dense,
mask_dense,
)
else:
return (
block_mask_dense_output,
block_mask_dense,
None,
)
def binary_mask_to_bias(mask_dense: torch.Tensor):
mask_dense = 1 - mask_dense
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
return mask_dense
def get_head_sliding_step(n_heads: int,
vert_stride: int,
homo_head: bool = False):
if homo_head:
return 0
return max(1, int(vert_stride / n_heads))
@lru_cache
def get_sparse_attn_mask(
n_heads: int,
q_len: int,
max_seqlen: int,
dtype: torch.dtype,
device: torch.device,
block_size: int = 64,
local_blocks: int = 4,
vert_stride: int = 4,
homo_head: bool = True,
return_dense: bool = False,
dense_mask_type: str = "binary",
):
"""
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
"""
assert dense_mask_type in ("binary", "bias")
if homo_head:
with torch.no_grad():
(crow, col), block_mask_dense, mask_dense = (
_get_sparse_attn_mask_homo_head(
q_len,
max_seqlen,
dtype,
device,
block_size,
local_blocks,
vert_stride,
return_dense,
))
crow = crow[None].expand(n_heads, crow.shape[0])
col = col[None].expand(n_heads, col.shape[0])
if return_dense:
mask_dense = mask_dense[None].expand(n_heads,
*mask_dense.shape)
if dense_mask_type == "bias":
mask_dense = binary_mask_to_bias(mask_dense)
return (crow, col), block_mask_dense, mask_dense
with torch.no_grad():
num_blocks = triton.cdiv(max_seqlen, block_size)
q_pos = torch.arange(num_blocks)[None, :, None]
k_pos = torch.arange(num_blocks)[None, None]
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
mask_vert_strided = [
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
vert_stride == 0 for h in range(n_heads)
]
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
block_mask_dense = (((q_pos >= k_pos)
& ((q_pos - k_pos < local_blocks)
| mask_vert_strided)).to(device).to(dtype))
num_blocks_q = triton.cdiv(q_len, block_size)
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
if return_dense:
mask_dense = torch.kron(
block_mask_dense,
block_mask_dense.new_ones((block_size, block_size)),
)
causal_mask = torch.tril(torch.ones(
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
if dense_mask_type == "bias":
mask_dense = binary_mask_to_bias(mask_dense)
return (
dense_to_crow_col(block_mask_dense_output),
block_mask_dense,
mask_dense,
)
else:
return (
dense_to_crow_col(block_mask_dense_output),
block_mask_dense,
None,
)

View File

@ -91,9 +91,21 @@ class PagedAttention:
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
output = torch.empty_like(query)
if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
@ -107,6 +119,7 @@ class PagedAttention:
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
@ -123,6 +136,11 @@ class PagedAttention:
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
# Run PagedAttention V2.
@ -155,6 +173,11 @@ class PagedAttention:
alibi_slopes,
kv_cache_dtype,
kv_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output

View File

@ -29,7 +29,14 @@ def get_attn_backend(
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
"""Determine which attention backend to use and only import
the selected backend module.
"""

View File

@ -100,6 +100,7 @@ class OpenAIServing:
token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
token_logprob = max(token_logprob, -9999.0)
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:

View File

@ -56,6 +56,7 @@ _GENERATION_MODELS = {
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
}
_EMBEDDING_MODELS = {

View File

@ -0,0 +1,447 @@
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter,
loaded_weight: torch.Tensor):
tp = get_tensor_model_parallel_world_size()
rk = get_tensor_model_parallel_rank()
assert param.size(0) * tp == loaded_weight.size(0)
s = rk * param.size(0)
e = (rk + 1) * param.size(0)
loaded_weight = loaded_weight[s:e]
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
class HeadMajorQKVParallelLinear(QKVParallelLinear):
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor):
return load_column_parallel_weight(param, loaded_weight)
class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor):
return load_column_parallel_weight(param, loaded_weight)
@torch.jit.script
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
@torch.jit.script
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:
a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
a_gelu.clamp(min=None, max=limit))
a_linear = torch.where(
torch.isinf(a_linear),
a_linear,
a_linear.clamp(min=-limit, max=limit),
)
out_gelu = quick_gelu(a_gelu)
return out_gelu * (a_linear + 1)
class Phi3SmallMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
assert (self.config.hidden_act == "gegelu"
), "Only `gegelu` is supported for the 4.7 series of models .."
self.hidden_size = config.hidden_size
self.gegelu_limit = config.gegelu_limit
self.intermediate_size = config.intermediate_size
self.up_proj = HeadMajorColumnParallelLinear(
self.hidden_size,
2 * [self.intermediate_size],
bias=True,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
def forward(self, x):
gate_up, _ = self.up_proj(x)
x = gegelu(gate_up)
x, _ = self.down_proj(x)
return x
class Phi3SmallSelfAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.sparse_block_size = config.blocksparse_block_size
self.homo_heads = config.blocksparse_homo_head_pattern
self.local_blocks = config.blocksparse_num_local_blocks
self.vert_stride = config.blocksparse_vert_stride
assert (config.blocksparse_block_size ==
config.blocksparse_triton_kernel_block_size)
self.hidden_size = config.hidden_size
# Number of Query Heads
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.tp_size = get_tensor_model_parallel_world_size()
# Number of total Key Value Heads before tensor parallel
self.num_key_value_heads = config.num_key_value_heads
self.num_q_per_kv = self.num_heads // self.num_key_value_heads
if self.tp_size > 1:
assert self.num_key_value_heads % self.tp_size == 0
self.num_kv_heads_per_partion = max(
1, self.num_key_value_heads // self.tp_size)
self.num_heads_per_partition = self.num_heads // self.tp_size
self.max_position_embeddings = config.max_position_embeddings
self.rope_embedding_base = config.rope_embedding_base
self.rope_position_scale = config.rope_position_scale
self.is_causal = True
norm_factor = None
if config.mup_use_scaling:
norm_factor = self.head_dim / config.mup_attn_multiplier
else:
norm_factor = math.sqrt(self.head_dim)
self.scale = 1 / norm_factor
self.query_key_value = HeadMajorQKVParallelLinear(
self.hidden_size,
self.head_dim,
self.num_heads,
self.num_key_value_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config)
if getattr(self.config, "rope_scaling", None) is not None:
rope_scaling = self.config.rope_scaling
for key in rope_scaling:
if isinstance(rope_scaling[key], list):
rope_scaling[key] = tuple(rope_scaling[key])
if "factor" not in rope_scaling:
rope_scaling["factor"] = self.rope_position_scale
else:
rope_scaling = {
"type": "linear",
"factor": self.rope_position_scale,
}
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_embedding_base,
rope_scaling=rope_scaling,
)
# blocksparse params
self.blocksparse_block_size = config.blocksparse_block_size
self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
self.blocksparse_vert_stride = config.blocksparse_vert_stride
use_dense_attn = (getattr(self.config,
"dense_attention_every_n_layers", None)
and (self.layer_idx + 1) %
self.config.dense_attention_every_n_layers == 0)
bs_params = None
if not use_dense_attn:
bs_params = {
'max_seqlen': self.max_position_embeddings,
'num_heads': self.num_heads_per_partition,
"num_kv_heads": self.num_kv_heads_per_partion,
"block_size": self.sparse_block_size,
"local_blocks": self.local_blocks,
"vert_stride": self.vert_stride,
"homo_head": self.homo_heads
}
self.attn = Attention(
self.num_heads_per_partition,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads_per_partion,
cache_config=cache_config,
quant_config=quant_config,
blocksparse_params=bs_params,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
qkv, _ = self.query_key_value(hidden_states)
qkv = qkv.view(qkv.shape[:-1] +
(-1, (self.num_q_per_kv + 2), self.head_dim))
q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
# NOTE: this is required by RotaryEmbed, which indeed does not have to
# TODO: allow 3D QK for rotary forward
q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
output, _ = self.dense(attn_output)
return output
class Phi3SmallDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi3SmallSelfAttention(config,
layer_idx,
cache_config=cache_config,
quant_config=quant_config)
self.mlp = Phi3SmallMLP(config, quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Phi3SmallModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.mup_embedding_multiplier = config.mup_embedding_multiplier
self.layers = nn.ModuleList([
Phi3SmallDecoderLayer(config, layer_idx, cache_config,
quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata = None,
):
hidden_states = self.embed_tokens(input_ids)
if (self.mup_embedding_multiplier is not None
and self.mup_embedding_multiplier > 0.0):
hidden_states = hidden_states * self.mup_embedding_multiplier
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class Phi3SmallForCausalLM(nn.Module):
_tied_weights_keys = ["lm_head.weight"]
def __init__(
self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config)
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
# tokens in tiktoken but not used
if hasattr(config, 'dummy_token_indices'):
device = self.lm_head.weight.device
self.register_buffer('dummy_token_indices',
torch.LongTensor(
config.dummy_token_indices).to(device),
persistent=False)
else:
self.dummy_token_indices = None
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, value):
self.lm_head = value
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
if self.dummy_token_indices is not None and logits is not None:
logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
return logits
def forward(
self,
input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
output_hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
output_hidden_states = output_hidden_states
return output_hidden_states
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits / self.mup_width_multiplier,
sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)

View File

@ -63,4 +63,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert hasattr(config.text_config, "num_attention_heads")
return config.text_config
else:
return config
return config