diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 8ab2af22..2c3cae95 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride,
@@ -291,6 +292,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int rowid = laneid / 16;
   const auto seq_idx = blockIdx.x;
+  // NOTE: queries with sequence len > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
+    return;
+  }
+
   const auto partition_idx = blockIdx.y;

   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
@@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // fetch Q in shared across warps and then write to registers
   const int local_qhead_idx = 4 * warpid + rowid;
   const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-  const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   const scalar_t* q_ptr =
-      q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+      q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;

   const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
   if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride,
@@ -794,6 +804,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
   const int lane4id = laneid % 4;
   const auto seq_idx = blockIdx.x;
+  // NOTE: queries with sequence len > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
   const auto partition_idx = blockIdx.y;
   const auto partition_size = blockDim.x;
   const auto max_num_partitions = gridDim.y;
@@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
   }

   // fetch q elements
   // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   const scalar_t* q_ptr =
-      q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
+      q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE;
   const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
   const int qhead_elemh8 = laneid / 4;
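Note: the decode-only gating above keys off `query_start_loc`, the usual exclusive prefix sum of per-sequence query lengths with a leading zero (hence `num_seqs + 1` entries). A minimal sketch of the semantics the kernels assume, with illustrative values that are not from the PR:

```python
# query_start_loc has num_seqs + 1 entries; sequence i owns query rows
# [query_start_loc[i], query_start_loc[i + 1]). A row count of 1 marks a
# decode; anything longer is a prefill and is skipped by these kernels.
query_start_loc = [0, 8, 9, 10]  # hypothetical: one 8-token prefill, two decodes

for seq_idx in range(len(query_start_loc) - 1):
    query_len = query_start_loc[seq_idx + 1] - query_start_loc[seq_idx]
    print(seq_idx, query_len, "decode" if query_len == 1 else "prefill")
```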
@@ -1267,10 +1285,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                            //  max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_partitions) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
+
+  // NOTE: queries with sequence len > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
+
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
   [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;

-  OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
                   static_cast<int64_t>(head_idx) * HEAD_SIZE;
   if constexpr (std::is_same<OUTT, bf16_t>::value) {
     out_ptr[threadIdx.x] =
@@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride,
@@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride,
@@ -1515,6 +1546,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs + 1]
     const int max_num_partitions) {
   UNREACHABLE_CODE
 }
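Note: the reduce kernel now scatters each decode's single output row to `query_start_loc[seq_idx]` instead of `seq_idx`, so decode results land in the right slots of a token-major output shared with prefills. A quick sketch of the row mapping (illustrative values):

```python
# With query_start_loc present, decode seq i writes its output row at
# query_start_loc[i]; with it absent (V0 decode-only batches), row i is i.
query_start_loc = [0, 8, 9, 10]  # hypothetical: seq 0 is a prefill

def out_row(seq_idx, query_start_loc=None):
    return query_start_loc[seq_idx] if query_start_loc is not None else seq_idx

print([out_row(i, query_start_loc) for i in (1, 2)])  # [8, 9]
print([out_row(i) for i in (1, 2)])                   # [1, 2]
```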
@@ -1522,34 +1554,34 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

 #endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

-#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO)                             \
-  paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
-                                          HEAD_SIZE, NTHR, ALIBI_ENABLED,     \
-                                          GQA_RATIO>                          \
-      <<<grid, block, 0, stream>>>(                                           \
-          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
-          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
-          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
-          k_scale_ptr, v_scale_ptr);
+#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO)                              \
+  paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE,  \
+                                          HEAD_SIZE, NTHR, ALIBI_ENABLED,      \
+                                          GQA_RATIO>                           \
+      <<<grid, block, 0, stream>>>(                                            \
+          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,      \
+          block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
+          max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
+          kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
+          max_ctx_blocks, k_scale_ptr, v_scale_ptr);

-#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                              \
-  paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE,  \
-                                         HEAD_SIZE, NTHR, ALIBI_ENABLED,      \
-                                         GQA_RATIO>                           \
-      <<<grid, block, 0, stream>>>(                                           \
-          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
-          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
-          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
-          k_scale_ptr, v_scale_ptr);
+#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                               \
+  paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE,   \
+                                         HEAD_SIZE, NTHR, ALIBI_ENABLED,       \
+                                         GQA_RATIO>                            \
+      <<<grid, block, 0, stream>>>(                                            \
+          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,      \
+          block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
+          max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
+          kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
+          max_ctx_blocks, k_scale_ptr, v_scale_ptr);

 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                          \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
                                       PSIZE, NPAR_LOOPS>             \
       <<<reduce_grid, reduce_block, 0, stream>>>(                    \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,        \
-          context_lens_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions);

 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, int PSIZE, bool ALIBI_ENABLED>
 void paged_attention_custom_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
-    torch::Tensor& k_scale, torch::Tensor& v_scale) {
-  int num_seqs = query.size(0);
+    const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
+    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
+  int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
@@ -1569,6 +1602,13 @@ void paged_attention_custom_launcher(
   int kv_block_stride = key_cache.stride(0);
   int kv_head_stride = key_cache.stride(1);

+  // NOTE: query_start_loc is optional; for V0 decode it should not be used.
+  // If the batch mixes prefills and decodes, prefill rows must be skipped.
+  const int* query_start_loc_ptr =
+      query_start_loc
+          ? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
+          : nullptr;
+
   // NOTE: alibi_slopes is optional.
   const float* alibi_slopes_ptr =
       alibi_slopes
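Note: one subtle change above is that `num_seqs` is now derived from `block_tables` rather than `query`. With chunked prefill the query tensor is token-major, so its row count no longer equals the sequence count. A toy illustration with hypothetical shapes:

```python
import torch

# Hypothetical mixed batch: one 8-token prefill plus two decodes
# -> 10 query rows but only 3 sequences.
query = torch.randn(10, 8, 128)  # [num_tokens, num_heads, head_size]
block_tables = torch.zeros(3, 4, dtype=torch.int32)  # [num_seqs, max_blocks]

num_seqs = block_tables.shape[0]   # 3 -- one grid row per sequence
assert num_seqs != query.shape[0]  # query.size(0) would be wrong here
```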
@@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
   paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
                                   PSIZE, ALIBI_ENABLED>(                      \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
-      num_kv_heads, scale, block_tables, context_lens, max_context_len,       \
-      alibi_slopes, k_scale, v_scale);
+      num_kv_heads, scale, block_tables, context_lens, query_start_loc,       \
+      max_context_len, alibi_slopes, k_scale, v_scale);

 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,     \
                                    PSIZE)                                     \
@@ -1750,6 +1790,7 @@ void paged_attention(
     double scale,
     torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
     torch::Tensor& context_lens,  // [num_seqs]
+    const std::optional<torch::Tensor>& query_start_loc,  // [num_seqs + 1]
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index ba161951..afb73545 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -7,8 +7,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                      torch::Tensor& query, torch::Tensor& key_cache,
                      torch::Tensor& value_cache, int64_t num_kv_heads,
                      double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens, int64_t block_size,
-                     int64_t max_context_len,
+                     torch::Tensor& context_lens,
+                     const std::optional<torch::Tensor>& query_start_loc,
+                     int64_t block_size, int64_t max_context_len,
                      const std::optional<torch::Tensor>& alibi_slopes,
                      const std::string& kv_cache_dtype, torch::Tensor& k_scale,
                      torch::Tensor& v_scale);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index a5d2e2f9..537e9357 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -23,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "                Tensor query, Tensor key_cache,"
       "                Tensor value_cache, int num_kv_heads,"
       "                float scale, Tensor block_tables,"
-      "                Tensor context_lens, int block_size,"
+      "                Tensor context_lens,"
+      "                Tensor? query_start_loc,"
+      "                int block_size,"
       "                int max_context_len,"
      "                Tensor? alibi_slopes,"
       "                str kv_cache_dtype,"
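Note: with the `Tensor? query_start_loc` schema change, Python callers pass either `None` (mapped to `nullptr` on the C++ side) or the usual varlen offsets tensor. Building that tensor is just an exclusive cumulative sum with a leading zero; a small runnable sketch with hypothetical lengths:

```python
import torch

# query_start_loc: exclusive cumsum over per-sequence query lengths,
# with num_seqs + 1 entries. V0 decode-only callers pass None instead.
query_lens = torch.tensor([8, 1, 1], dtype=torch.int32)  # hypothetical batch
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc.tolist())  # [0, 8, 9, 10]
```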
alibi_slopes," " str kv_cache_dtype," diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 50eaa92f..9333777d 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -164,6 +164,7 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, + MAX_CTX_LEN, max_input_len, k_scale, v_scale, @@ -180,6 +181,7 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, + MAX_CTX_LEN, max_input_len, k_scale, v_scale, @@ -397,6 +399,7 @@ def test_contexted_kv_attention_alibi( block_table, b_start_loc, b_seq_len, + MAX_CTX_LEN, max_input_len, k_scale, v_scale, @@ -413,6 +416,7 @@ def test_contexted_kv_attention_alibi( block_table, b_start_loc, b_seq_len, + MAX_CTX_LEN, max_input_len, k_scale, v_scale, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe41a2d9..719e02ec 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -110,6 +110,7 @@ def paged_attention_rocm( scale: float, block_tables: torch.Tensor, seq_lens: torch.Tensor, + query_start_loc: Optional[torch.Tensor], block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], @@ -120,8 +121,9 @@ def paged_attention_rocm( torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale) + query_start_loc, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, + v_scale) def mla_decode_kvcache_cpu( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f19773bb..9a4ee2ae 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,16 +17,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.rocm import use_rocm_custom_paged_attention if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata logger = init_logger(__name__) - _PARTITION_SIZE_ROCM = 256 -_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName -_ON_NAVI = "gfx1" in _GPU_ARCH -_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) class ROCmFlashAttentionBackend(AttentionBackend): @@ -790,9 +787,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads - use_custom = _use_rocm_custom_paged_attention( + use_custom = use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len) + decode_meta.max_decode_seq_len, self.sliding_window) if use_custom: max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type != AttentionType.ENCODER_DECODER else @@ -817,6 +814,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): out = output[num_prefill_tokens:] else: out = output + + query_start_loc = None ops.paged_attention_rocm( out, exp_sums, @@ -833,6 +832,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): decode_meta.seq_lens_tensor if self.attn_type != AttentionType.ENCODER_DECODER else decode_meta.encoder_seq_lens_tensor, + query_start_loc, block_size, max_seq_len, self.alibi_slopes, @@ -898,15 +898,3 @@ def _sdpa_attention( start = end return output - - -def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, 
diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 48db3ebf..1b475816 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -10,6 +10,9 @@
 import torch
 import triton
 import triton.language as tl

+from vllm import _custom_ops as ops
+from vllm.platforms.rocm import use_rocm_custom_paged_attention
+
 from .prefix_prefill import context_attention_fwd


@@ -33,26 +36,26 @@ def kernel_paged_attention_2d(
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        num_queries_per_kv_padded: tl.constexpr,  # int
-       block_table_stride: tl.constexpr,  # int
-       query_stride_0: tl.constexpr,  # int
-       query_stride_1: tl.constexpr,  # int, should be equal to head_size
-       output_stride_0: tl.constexpr,  # int
-       output_stride_1: tl.constexpr,  # int, should be equal to head_size
+       block_table_stride: tl.int64,  # int
+       query_stride_0: tl.int64,  # int
+       query_stride_1: tl.int64,  # int, should be equal to head_size
+       output_stride_0: tl.int64,  # int
+       output_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        x: tl.constexpr,  # int
-       stride_k_cache_0: tl.constexpr,  # int
-       stride_k_cache_1: tl.constexpr,  # int
-       stride_k_cache_2: tl.constexpr,  # int
-       stride_k_cache_3: tl.constexpr,  # int
-       stride_k_cache_4: tl.constexpr,  # int
-       stride_v_cache_0: tl.constexpr,  # int
-       stride_v_cache_1: tl.constexpr,  # int
-       stride_v_cache_2: tl.constexpr,  # int
-       stride_v_cache_3: tl.constexpr,  # int
+       stride_k_cache_0: tl.int64,  # int
+       stride_k_cache_1: tl.int64,  # int
+       stride_k_cache_2: tl.int64,  # int
+       stride_k_cache_3: tl.int64,  # int
+       stride_k_cache_4: tl.int64,  # int
+       stride_v_cache_0: tl.int64,  # int
+       stride_v_cache_1: tl.int64,  # int
+       stride_v_cache_2: tl.int64,  # int
+       stride_v_cache_3: tl.int64,  # int
        filter_by_query_len: tl.constexpr,  # bool
        query_start_len_ptr,  # [num_seqs+1]
 ):
@@ -212,6 +215,7 @@ def chunked_prefill_paged_decode(
         block_table,
         query_start_loc,
         seq_lens,
+        max_seq_len,
         max_query_len,
         k_scale,
         v_scale,
@@ -240,6 +244,7 @@ def chunked_prefill_paged_decode(
             b_loc=block_table,
             b_start_loc=query_start_loc,
             b_seq_len=seq_lens,
+            max_seq_len=max_seq_len,
             max_input_len=max_query_len,
             k_scale=k_scale,
             v_scale=v_scale,
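Note: promoting the stride arguments from `tl.constexpr` to `tl.int64` plausibly does two things (my reading of the change, not stated in the PR): the kernel is no longer re-specialized for every distinct stride value, and offset arithmetic is carried in 64 bits, which matters once a cache's flattened extent passes 2^31 elements. A back-of-the-envelope check with hypothetical sizes:

```python
# Hypothetical worst case: a KV cache big enough that element offsets no
# longer fit in int32, motivating 64-bit stride arithmetic in the kernel.
num_blocks, block_size, num_kv_heads, head_size = 200_000, 16, 8, 128
flat_extent = num_blocks * block_size * num_kv_heads * head_size
print(flat_extent)              # 3_276_800_000
print(flat_extent > 2**31 - 1)  # True: int32 offsets would overflow
```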
@@ -275,43 +280,87 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)

-    kernel_paged_attention_2d[(
-        num_seqs,
-        num_kv_heads,
-    )](
-        output_ptr=output,
-        query_ptr=query,
-        key_cache_ptr=key_cache,
-        value_cache_ptr=value_cache,
-        block_tables_ptr=block_table,
-        seq_lens_ptr=seq_lens,
-        alibi_slopes_ptr=alibi_slopes,
-        scale=sm_scale,
-        k_scale=k_scale,
-        v_scale=v_scale,
-        num_query_heads=num_query_heads,
-        num_queries_per_kv=num_queries_per_kv,
-        num_queries_per_kv_padded=num_queries_per_kv_padded,
-        block_table_stride=block_table.stride(0),
-        query_stride_0=query.stride(0),
-        query_stride_1=query.stride(1),
-        output_stride_0=output.stride(0),
-        output_stride_1=output.stride(1),
-        BLOCK_SIZE=block_size,
-        HEAD_SIZE=head_size,
-        HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
-        USE_ALIBI_SLOPES=use_alibi_slopes,
-        SLIDING_WINDOW=sliding_window,
-        x=key_cache.shape[4],
-        stride_k_cache_0=key_cache.stride(0),
-        stride_k_cache_1=key_cache.stride(1),
-        stride_k_cache_2=key_cache.stride(2),
-        stride_k_cache_3=key_cache.stride(3),
-        stride_k_cache_4=key_cache.stride(4),
-        stride_v_cache_0=value_cache.stride(0),
-        stride_v_cache_1=value_cache.stride(1),
-        stride_v_cache_2=value_cache.stride(2),
-        stride_v_cache_3=value_cache.stride(3),
-        filter_by_query_len=True,
-        query_start_len_ptr=query_start_loc,
-    )
+    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
+                                                 block_size,
+                                                 num_queries_per_kv,
+                                                 max_seq_len, sliding_window)
+    if use_custom:
+        _PARTITION_SIZE_ROCM = 256
+        max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
+                              _PARTITION_SIZE_ROCM)
+        assert _PARTITION_SIZE_ROCM % block_size == 0
+        total_num_seq = query.shape[0]
+        tmp_output = torch.empty(
+            size=(total_num_seq, num_query_heads, max_num_partitions,
+                  head_size),
+            dtype=output.dtype,
+            device=output.device,
+        )
+        exp_sums = torch.empty(
+            size=(total_num_seq, num_query_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=output.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+
+        ops.paged_attention_rocm(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            scale=sm_scale,
+            block_tables=block_table,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            block_size=block_size,
+            max_seq_len=max_seq_len,
+            alibi_slopes=alibi_slopes,
+            kv_cache_dtype=kv_cache_dtype,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+    else:
+        kernel_paged_attention_2d[(
+            num_seqs,
+            num_kv_heads,
+        )](
+            output_ptr=output,
+            query_ptr=query,
+            key_cache_ptr=key_cache,
+            value_cache_ptr=value_cache,
+            block_tables_ptr=block_table,
+            seq_lens_ptr=seq_lens,
+            alibi_slopes_ptr=alibi_slopes,
+            scale=sm_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            num_query_heads=num_query_heads,
+            num_queries_per_kv=num_queries_per_kv,
+            num_queries_per_kv_padded=num_queries_per_kv_padded,
+            block_table_stride=block_table.stride(0),
+            query_stride_0=query.stride(0),
+            query_stride_1=query.stride(1),
+            output_stride_0=output.stride(0),
+            output_stride_1=output.stride(1),
+            BLOCK_SIZE=block_size,
+            HEAD_SIZE=head_size,
+            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
+            USE_ALIBI_SLOPES=use_alibi_slopes,
+            SLIDING_WINDOW=sliding_window,
+            x=key_cache.shape[4],
+            stride_k_cache_0=key_cache.stride(0),
+            stride_k_cache_1=key_cache.stride(1),
+            stride_k_cache_2=key_cache.stride(2),
+            stride_k_cache_3=key_cache.stride(3),
+            stride_k_cache_4=key_cache.stride(4),
+            stride_v_cache_0=value_cache.stride(0),
+            stride_v_cache_1=value_cache.stride(1),
+            stride_v_cache_2=value_cache.stride(2),
+            stride_v_cache_3=value_cache.stride(3),
+            filter_by_query_len=True,
+            query_start_len_ptr=query_start_loc,
+        )
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index fd703413..827c3041 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -209,6 +209,7 @@ class PagedAttention:
         v_scale: torch.Tensor,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
+        max_seq_len = None
         context_attention_fwd(
             query,
             key,
@@ -221,6 +222,7 @@ class PagedAttention:
             # query_start_loc is (batch_size + 1,)
             query_start_loc,
             seq_lens_tensor,
+            max_seq_len,
             max_query_len,
             k_scale,
             v_scale,
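Note: the custom path above sizes its split-K workspace from `max_seq_len` alone, one partial result per 256-token partition. To get a feel for the footprint, here is the same arithmetic with hypothetical batch dimensions:

```python
# Workspace sizing mirrored from the diff: one partial per 256-token partition.
PARTITION_SIZE = 256
max_seq_len, total_num_seq, num_query_heads, head_size = 8192, 64, 32, 128

max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE  # 32
tmp_output_elems = (total_num_seq * num_query_heads * max_num_partitions
                    * head_size)
exp_sums_elems = total_num_seq * num_query_heads * max_num_partitions
print(max_num_partitions, tmp_output_elems, exp_sums_elems)
```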
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index e85ec605..49ba476d 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -725,6 +725,7 @@ if triton.__version__ >= "2.1.0":
                               b_loc,
                               b_start_loc,
                               b_seq_len,
+                              max_seq_len,
                               max_input_len,
                               k_scale: torch.Tensor,
                               v_scale: torch.Tensor,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 89b778c7..1d071430 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-from functools import lru_cache, wraps
+from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING, Dict, List, Optional

 import torch
@@ -98,6 +98,25 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id


+@cache
+def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                    block_size: int, gqa_ratio: int,
+                                    max_seq_len: int,
+                                    sliding_window: int) -> bool:
+
+    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+    ON_NAVI = "gfx1" in GPU_ARCH
+    ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
+
+    # rocm custom paged attention is not supported on navi (gfx1*)
+    return (ON_MI250_MI300 and not ON_NAVI and (sliding_window == 0)
+            and (qtype == torch.half or qtype == torch.bfloat16)
+            and (head_size == 64 or head_size == 128)
+            and (block_size == 16 or block_size == 32)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
+            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+
+
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
     device_name: str = "rocm"
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index f11f2b62..15b49b14 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -168,6 +168,7 @@ class TritonAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             query_start_loc=attn_metadata.query_start_loc,
             seq_lens=attn_metadata.seq_lens,
+            max_seq_len=attn_metadata.max_seq_len,
             max_query_len=attn_metadata.max_query_len,
             k_scale=layer._k_scale,
             v_scale=layer._v_scale,
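Note: the relocated gate is wrapped in `functools.cache`, so the relatively expensive `torch.cuda.get_device_properties` lookup runs at most once per distinct argument tuple; this also requires all arguments to be hashable, which dtypes and ints are. A minimal illustration of that caching behavior (a toy predicate, not the vLLM function):

```python
from functools import cache

@cache
def gate(head_size: int, block_size: int) -> bool:
    print("evaluating")  # side effect to show when the body actually runs
    return head_size in (64, 128) and block_size in (16, 32)

gate(128, 16)  # prints "evaluating", returns True
gate(128, 16)  # served from the cache: no print
gate(64, 32)   # new argument tuple: evaluates again
```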