[ROCM][KERNEL] Paged attention for V1 (#15720)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
parent bd7599d34a
commit e73ff24e31
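The kernels in this diff gate their work on an optional query_start_loc tensor: a sequence is treated as a decode only when its query length is 1, and prefill rows are left to another kernel. A minimal Python sketch of that check, assuming the usual cumulative-offset layout of query_start_loc ([num_seqs + 1]); the decode_mask helper below is illustrative and not part of the change:

# Sketch only: mirrors the per-sequence early-return check in the kernels below.
import torch

def decode_mask(query_start_loc: torch.Tensor) -> torch.Tensor:
    # query_start_loc[i] is the offset of sequence i's first query token in the
    # flattened query tensor, so consecutive differences are per-sequence query lengths.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    return query_lens == 1  # True -> decode row, handled by the custom kernel

# Example: query lengths 5 (prefill), 1 (decode), 1 (decode)
qsl = torch.tensor([0, 5, 6, 7], dtype=torch.int32)
print(decode_mask(qsl))  # tensor([False,  True,  True])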
@@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -291,6 +292,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int rowid = laneid / 16;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
return;
}

const auto partition_idx = blockIdx.y;

constexpr int T_PAR_SIZE = 256; // token partition size set to 256
@@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 4 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;

const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -794,6 +804,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4id = laneid % 4;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const auto partition_idx = blockIdx.y;
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
@@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}

// fetch q elements
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elemsc
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;

@@ -1267,10 +1285,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;

// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}

const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;

OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
if constexpr (std::is_same<OUTT, bit8_t>::value) {
out_ptr[threadIdx.x] =
@@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -1515,6 +1546,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
UNREACHABLE_CODE
}
@@ -1522,34 +1554,34 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions);

template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1559,9 +1591,10 @@ void paged_attention_custom_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
@@ -1569,6 +1602,13 @@ void paged_attention_custom_launcher(
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);

// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
@@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
@@ -1750,6 +1790,7 @@ void paged_attention(
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,

@@ -7,8 +7,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);

@@ -23,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" Tensor context_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"

@@ -164,6 +164,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@@ -180,6 +181,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@@ -397,6 +399,7 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@@ -413,6 +416,7 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,

@@ -110,6 +110,7 @@ def paged_attention_rocm(
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
@@ -120,8 +121,9 @@ def paged_attention_rocm(
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)


def mla_decode_kvcache_cpu(

@@ -17,16 +17,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])


class ROCmFlashAttentionBackend(AttentionBackend):
@@ -790,9 +787,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
decode_meta.max_decode_seq_len, self.sliding_window)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
@@ -817,6 +814,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
out = output[num_prefill_tokens:]
else:
out = output

query_start_loc = None
ops.paged_attention_rocm(
out,
exp_sums,
@@ -833,6 +832,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
query_start_loc,
block_size,
max_seq_len,
self.alibi_slopes,
@@ -898,15 +898,3 @@ def _sdpa_attention(
start = end

return output


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)

@@ -10,6 +10,9 @@ import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention

from .prefix_prefill import context_attention_fwd


@@ -33,26 +36,26 @@ def kernel_paged_attention_2d(
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.constexpr, # int
query_stride_0: tl.constexpr, # int
query_stride_1: tl.constexpr, # int, should be equal to head_size
output_stride_0: tl.constexpr, # int
output_stride_1: tl.constexpr, # int, should be equal to head_size
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.constexpr, # int
stride_k_cache_1: tl.constexpr, # int
stride_k_cache_2: tl.constexpr, # int
stride_k_cache_3: tl.constexpr, # int
stride_k_cache_4: tl.constexpr, # int
stride_v_cache_0: tl.constexpr, # int
stride_v_cache_1: tl.constexpr, # int
stride_v_cache_2: tl.constexpr, # int
stride_v_cache_3: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
):
@@ -212,6 +215,7 @@ def chunked_prefill_paged_decode(
block_table,
query_start_loc,
seq_lens,
max_seq_len,
max_query_len,
k_scale,
v_scale,
@@ -240,6 +244,7 @@ def chunked_prefill_paged_decode(
b_loc=block_table,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_seq_len=max_seq_len,
max_input_len=max_query_len,
k_scale=k_scale,
v_scale=v_scale,
@@ -275,43 +280,87 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)

kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size,
num_queries_per_kv,
max_seq_len, sliding_window)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = query.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)

ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale=sm_scale,
block_tables=block_table,
seq_lens=seq_lens,
query_start_loc=query_start_loc,
block_size=block_size,
max_seq_len=max_seq_len,
alibi_slopes=alibi_slopes,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
)
else:
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
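For reference, the scratch buffers allocated above are sized by a simple ceiling division of max_seq_len by the 256-token partition size; a worked example with illustrative numbers (not taken from the diff):

# Illustrative arithmetic for the partitioning used before ops.paged_attention_rocm.
_PARTITION_SIZE_ROCM = 256
block_size = 16          # example value; must divide the partition size
max_seq_len = 1000       # example value

assert _PARTITION_SIZE_ROCM % block_size == 0
max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
print(max_num_partitions)  # 4, so tmp_output has shape [num_seqs, num_query_heads, 4, head_size]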
@@ -209,6 +209,7 @@ class PagedAttention:
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
max_seq_len = None
context_attention_fwd(
query,
key,
@@ -221,6 +222,7 @@ class PagedAttention:
# query_start_loc is (batch_size + 1,)
query_start_loc,
seq_lens_tensor,
max_seq_len,
max_query_len,
k_scale,
v_scale,

@@ -725,6 +725,7 @@ if triton.__version__ >= "2.1.0":
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from functools import lru_cache, wraps
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional

import torch
@@ -98,6 +98,25 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id


@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:

GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])

# rocm custom page attention not support on navi (gfx1*)
return (ON_MI250_MI300 and not ON_NAVI and (sliding_window == 0)
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)


class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
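A usage sketch for the helper defined above (argument values are illustrative; the allowed ranges come from the conditions in the function body):

# Illustrative call; returns False on Navi (gfx1*), with a sliding window, or
# when VLLM_ROCM_CUSTOM_PAGED_ATTN is disabled.
import torch
from vllm.platforms.rocm import use_rocm_custom_paged_attention

use_custom = use_rocm_custom_paged_attention(
    qtype=torch.bfloat16,  # half or bfloat16 only
    head_size=128,         # 64 or 128
    block_size=16,         # 16 or 32
    gqa_ratio=8,           # 1..16
    max_seq_len=4096,      # <= 32768
    sliding_window=0,      # custom kernel path requires no sliding window
)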
@@ -168,6 +168,7 @@ class TritonAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
max_query_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,