[ROCM][KERNEL] Paged attention for V1 (#15720)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
Aleksandr Malyshev, 2025-04-02 19:48:00 -07:00, committed by GitHub
parent bd7599d34a
commit e73ff24e31
11 changed files with 219 additions and 109 deletions

@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -291,6 +292,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int rowid = laneid / 16;
const auto seq_idx = blockIdx.x;
// NOTE: queries with sequence len > 1 are prefills and are handled by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
return;
}
const auto partition_idx = blockIdx.y;
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 4 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -794,6 +804,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4id = laneid % 4;
const auto seq_idx = blockIdx.x;
// NOTE: queries with sequence len > 1 are prefills and are handled by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const auto partition_idx = blockIdx.y;
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
// fetch q elements
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;
@ -1267,10 +1285,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_partitions) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
// NOTE: queries with sequence len > 1 are prefills and are handled by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;
OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
if constexpr (std::is_same<OUTT, bit8_t>::value) {
out_ptr[threadIdx.x] =
@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -1515,6 +1546,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs + 1]
const int max_num_partitions) {
UNREACHABLE_CODE
}
@ -1522,34 +1554,34 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@ -1559,9 +1591,10 @@ void paged_attention_custom_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
@ -1569,6 +1602,13 @@ void paged_attention_custom_launcher(
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: query_start_loc is optional and is not used for V0 decode.
// If the batch contains a mix of prefills and decodes, prefills are skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
@ -1750,6 +1790,7 @@ void paged_attention(
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs + 1]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
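
For reference, a minimal host-side sketch (plain Python; the helper name decode_query_rows and the sample tensor are illustrative, not part of this change) of the indexing the decode kernels above now perform: with a cumulative query_start_loc of length num_seqs + 1, any sequence whose query length is not 1 is a prefill and the kernel returns early, while a decode sequence reads its single query row at query_start_loc[seq_idx]; without query_start_loc (the V0 path) the row is simply seq_idx.

```python
from typing import Optional

import torch


def decode_query_rows(query_start_loc: Optional[torch.Tensor],
                      num_seqs: int) -> list:
    """Per sequence, the query row a decode kernel reads, or None when the
    sequence is a prefill and the kernel returns early (handled elsewhere)."""
    rows = []
    for seq_idx in range(num_seqs):
        if query_start_loc is None:
            # V0 decode: exactly one query per sequence, row == seq_idx.
            rows.append(seq_idx)
            continue
        q_len = int(query_start_loc[seq_idx + 1] - query_start_loc[seq_idx])
        if q_len != 1:
            rows.append(None)  # prefill: skipped by the decode kernels
        else:
            rows.append(int(query_start_loc[seq_idx]))
    return rows


# Three sequences with query lengths [1, 5, 1] (mixed decode + prefill):
qsl = torch.tensor([0, 1, 6, 7], dtype=torch.int32)
assert decode_query_rows(qsl, 3) == [0, None, 6]
assert decode_query_rows(None, 3) == [0, 1, 2]
```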

@ -7,8 +7,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);

@ -23,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" Tensor context_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"

@ -164,6 +164,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@ -180,6 +181,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@ -397,6 +399,7 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
@ -413,6 +416,7 @@ def test_contexted_kv_attention_alibi(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,

@ -110,6 +110,7 @@ def paged_attention_rocm(
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
@ -120,8 +121,9 @@ def paged_attention_rocm(
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)
def mla_decode_kvcache_cpu(

@ -17,16 +17,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])
class ROCmFlashAttentionBackend(AttentionBackend):
@ -790,9 +787,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
decode_meta.max_decode_seq_len, self.sliding_window)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
@ -817,6 +814,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
out = output[num_prefill_tokens:]
else:
out = output
query_start_loc = None
ops.paged_attention_rocm(
out,
exp_sums,
@ -833,6 +832,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
query_start_loc,
block_size,
max_seq_len,
self.alibi_slopes,
@ -898,15 +898,3 @@ def _sdpa_attention(
start = end
return output
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)

@ -10,6 +10,9 @@ import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from .prefix_prefill import context_attention_fwd
@ -33,26 +36,26 @@ def kernel_paged_attention_2d(
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
block_table_stride: tl.constexpr, # int
query_stride_0: tl.constexpr, # int
query_stride_1: tl.constexpr, # int, should be equal to head_size
output_stride_0: tl.constexpr, # int
output_stride_1: tl.constexpr, # int, should be equal to head_size
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
x: tl.constexpr, # int
stride_k_cache_0: tl.constexpr, # int
stride_k_cache_1: tl.constexpr, # int
stride_k_cache_2: tl.constexpr, # int
stride_k_cache_3: tl.constexpr, # int
stride_k_cache_4: tl.constexpr, # int
stride_v_cache_0: tl.constexpr, # int
stride_v_cache_1: tl.constexpr, # int
stride_v_cache_2: tl.constexpr, # int
stride_v_cache_3: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_k_cache_4: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
):
@ -212,6 +215,7 @@ def chunked_prefill_paged_decode(
block_table,
query_start_loc,
seq_lens,
max_seq_len,
max_query_len,
k_scale,
v_scale,
@ -240,6 +244,7 @@ def chunked_prefill_paged_decode(
b_loc=block_table,
b_start_loc=query_start_loc,
b_seq_len=seq_lens,
max_seq_len=max_seq_len,
max_input_len=max_query_len,
k_scale=k_scale,
v_scale=v_scale,
@ -275,43 +280,87 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size,
num_queries_per_kv,
max_seq_len, sliding_window)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = query.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale=sm_scale,
block_tables=block_table,
seq_lens=seq_lens,
query_start_loc=query_start_loc,
block_size=block_size,
max_seq_len=max_seq_len,
alibi_slopes=alibi_slopes,
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
)
else:
kernel_paged_attention_2d[(
num_seqs,
num_kv_heads,
)](
output_ptr=output,
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
block_table_stride=block_table.stride(0),
query_stride_0=query.stride(0),
query_stride_1=query.stride(1),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
SLIDING_WINDOW=sliding_window,
x=key_cache.shape[4],
stride_k_cache_0=key_cache.stride(0),
stride_k_cache_1=key_cache.stride(1),
stride_k_cache_2=key_cache.stride(2),
stride_k_cache_3=key_cache.stride(3),
stride_k_cache_4=key_cache.stride(4),
stride_v_cache_0=value_cache.stride(0),
stride_v_cache_1=value_cache.stride(1),
stride_v_cache_2=value_cache.stride(2),
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
)
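
The workspace sizing in the new custom-kernel branch follows directly from the fixed 256-token partition size: the scratch buffers hold ceil(max_seq_len / 256) partitions per sequence. A short sketch of just that sizing (illustrative only; the helper name rocm_paged_attn_workspace is not part of this change, and the device defaults to CPU so the shapes can be checked anywhere):

```python
import torch


def rocm_paged_attn_workspace(max_seq_len: int, total_num_seq: int,
                              num_query_heads: int, head_size: int,
                              block_size: int, out_dtype: torch.dtype,
                              device: str = "cpu"):
    """Shapes of the scratch buffers allocated before ops.paged_attention_rocm."""
    partition_size = 256  # _PARTITION_SIZE_ROCM above
    assert partition_size % block_size == 0
    # Each sequence is processed in ceil(max_seq_len / partition_size) chunks.
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
    tmp_output = torch.empty(
        (total_num_seq, num_query_heads, max_num_partitions, head_size),
        dtype=out_dtype, device=device)
    exp_sums = torch.empty(
        (total_num_seq, num_query_heads, max_num_partitions),
        dtype=torch.float32, device=device)
    max_logits = torch.empty_like(exp_sums)
    return tmp_output, exp_sums, max_logits


# e.g. max_seq_len=1000, block_size=16 -> ceil(1000 / 256) = 4 partitions.
tmp, es, ml = rocm_paged_attn_workspace(1000, 8, 32, 128, 16, torch.float16)
assert es.shape == (8, 32, 4)
```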

@ -209,6 +209,7 @@ class PagedAttention:
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
max_seq_len = None
context_attention_fwd(
query,
key,
@ -221,6 +222,7 @@ class PagedAttention:
# query_start_loc is (batch_size + 1,)
query_start_loc,
seq_lens_tensor,
max_seq_len,
max_query_len,
k_scale,
v_scale,

@ -725,6 +725,7 @@ if triton.__version__ >= "2.1.0":
b_loc,
b_start_loc,
b_seq_len,
max_seq_len,
max_input_len,
k_scale: torch.Tensor,
v_scale: torch.Tensor,

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from functools import lru_cache, wraps
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
@ -98,6 +98,25 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
# ROCm custom paged attention is not supported on Navi (gfx1*)
return (ON_MI250_MI300 and not ON_NAVI and (sliding_window == 0)
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
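
The gating previously open-coded in rocm_flash_attn.py now lives here, additionally rejecting sliding-window attention, and @cache memoizes the result per argument tuple so the gcnArchName probe is not repeated. A usage sketch (argument values are illustrative; this must run on a ROCm GPU since the helper queries the current device):

```python
import torch
from vllm.platforms.rocm import use_rocm_custom_paged_attention

# A typical decode shape: fp16 queries, head_size 128, block_size 16,
# GQA ratio 8, 4K context, no sliding window.
use_custom = use_rocm_custom_paged_attention(
    torch.half,  # qtype
    128,         # head_size
    16,          # block_size
    8,           # gqa_ratio
    4096,        # max_seq_len
    0)           # sliding_window (non-zero disables the custom kernel)
print(use_custom)
```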

@ -168,6 +168,7 @@ class TritonAttentionImpl(AttentionImpl):
block_table=attn_metadata.block_table,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
max_query_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,