diff --git a/csrc/cache.h b/csrc/cache.h
index cf4a65c2..0970b704 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
+
+void gather_cache(
+    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
+    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
\ No newline at end of file
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0960888d..a6f8602a 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -2,6 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
     TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
   }
 }
+
+namespace vllm {
+
+// grid is launched with dimensions (batch, num_splits)
+template <typename scalar_t>
+__global__ void gather_cache(
+    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+                                              //  ENTRIES...]
+    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
+    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
+    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
+    const int32_t block_size, const int32_t entry_size,
+    const int64_t block_table_stride, const int64_t cache_block_stride,
+    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
+                                               // batch
+
+  const int64_t bid = blockIdx.x;  // Batch ID
+  const int32_t num_splits = gridDim.y;
+  const int32_t split = blockIdx.y;
+  const int32_t seq_start = cu_seq_lens[bid];
+  const int32_t seq_end = cu_seq_lens[bid + 1];
+  const int32_t seq_len = seq_end - seq_start;
+  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
+  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
+
+  const int32_t split_start = split * split_blocks;
+  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
+
+  const bool is_active_split = (split_start < tot_blocks);
+  const bool is_last_split = (split_end == tot_blocks);
+
+  if (!is_active_split) return;
+
+  int32_t full_blocks_end = split_end;
+  int32_t partial_block_size = 0;
+
+  // Adjust the pointer for the block_table for this batch.
+  // If seq_starts is provided, compute an offset based on (seq_starts[bid] /
+  // page_size)
+  const int32_t batch_offset = bid * block_table_stride;
+  int32_t offset = 0;
+  if (seq_starts != nullptr) {
+    offset = seq_starts[bid] / block_size;
+  }
+  const int32_t* batch_block_table = block_table + batch_offset + offset;
+
+  // Adjust dst pointer based on the cumulative sequence lengths.
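+  // (cu_seq_lens[bid] counts the tokens written for all preceding batch
+  //  elements, so this batch's rows begin at dst[seq_start] in the
+  //  [TOT_TOKENS, ENTRIES...] layout.)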
+  dst += seq_start * dst_entry_stride;
+
+  if (is_last_split) {
+    partial_block_size = seq_len % block_size;
+    if (partial_block_size) full_blocks_end -= 1;
+  }
+
+  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+                        scalar_t* __restrict__ _dst) {
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
+      _dst[i] = _src[i];
+  };
+
+  for (int pid = split_start; pid < full_blocks_end; ++pid) {
+    auto block_id = batch_block_table[pid];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
+    for (int eid = 0; eid < block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+
+  if (partial_block_size) {
+    auto block_id = batch_block_table[full_blocks_end];
+    auto block_start_ptr = src_cache + block_id * cache_block_stride;
+    auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
+    for (int eid = 0; eid < partial_block_size; ++eid) {
+      copy_entry(block_start_ptr + eid * cache_entry_stride,
+                 block_dst_ptr + eid * dst_entry_stride);
+    }
+  }
+}
+
+}  // namespace vllm
+
+// Macro to dispatch the kernel based on the data type.
+#define CALL_GATHER_CACHE(CPY_DTYPE)                                    \
+  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(            \
+      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),               \
+      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                     \
+      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+      block_size, entry_size, block_table_stride, cache_block_stride,   \
+      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+
+// Gather sequences from the cache into the destination tensor.
+// - cu_seq_lens contains the cumulative sequence lengths for each batch
+// - block_table contains the cache block indices for each sequence
+// - Optionally, seq_starts (if provided) offsets the starting block index by
+//   (seq_starts[bid] / page_size)
+void gather_cache(
+    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
+    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
+    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
+    int64_t batch_size,
+    std::optional<torch::Tensor> seq_starts = std::nullopt) {
+  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  int32_t block_size = src_cache.size(1);
+  int32_t entry_size = src_cache.flatten(2, -1).size(2);
+
+  TORCH_CHECK(block_table.dtype() == torch::kInt32,
+              "block_table must be int32");
+  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
+              "cu_seq_lens must be int32");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
+                "seq_starts must be int32");
+  }
+
+  TORCH_CHECK(src_cache.device() == dst.device(),
+              "src_cache and dst must be on the same device");
+  TORCH_CHECK(src_cache.device() == block_table.device(),
+              "src_cache and block_table must be on the same device");
+  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
+              "src_cache and cu_seq_lens must be on the same device");
+  if (seq_starts.has_value()) {
+    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
+                "src_cache and seq_starts must be on the same device");
+  }
+
+  int64_t block_table_stride = block_table.stride(0);
+  int64_t cache_block_stride = src_cache.stride(0);
+  int64_t cache_entry_stride = src_cache.stride(1);
+  int64_t dst_entry_stride = dst.stride(0);
+
+  // Decide on the number of splits based on the batch size.
+  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
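+  // (Smaller batches get more splits so the (batch, num_splits) grid still
+  //  launches enough thread blocks to keep the GPU occupied.)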
+  dim3 grid(batch_size, num_splits);
+  dim3 block(1024);
+
+  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
+              "src_cache and dst must have the same dtype");
+
+  const int dtype_bits = src_cache.element_size() * 8;
+  const int32_t* seq_starts_ptr =
+      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
+
+  if (dtype_bits == 32) {
+    CALL_GATHER_CACHE(uint32_t);
+  } else if (dtype_bits == 16) {
+    CALL_GATHER_CACHE(uint16_t);
+  } else if (dtype_bits == 8) {
+    CALL_GATHER_CACHE(uint8_t);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
+  }
+}
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
index ddfaca27..b8171133 100644
--- a/csrc/core/math.hpp
+++ b/csrc/core/math.hpp
@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
-
-template <typename T>
-inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
-  return (a + b - 1) / b;
-}
\ No newline at end of file
diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h
index 6f79d2b7..6e62ea20 100644
--- a/csrc/cuda_utils.h
+++ b/csrc/cuda_utils.h
@@ -2,10 +2,14 @@
 
 #include
 
-#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
-  #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
-  #define DEVICE_INLINE __forceinline__ __device__
-  #define HOST_INLINE __forceinline__ __host__
+#if defined(__HIPCC__)
+  #define HOST_DEVICE_INLINE __host__ __device__
+  #define DEVICE_INLINE __device__
+  #define HOST_INLINE __host__
+#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
+  #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
+  #define DEVICE_INLINE __device__ __forceinline__
+  #define HOST_INLINE __host__ __forceinline__
 #else
   #define HOST_DEVICE_INLINE inline
   #define DEVICE_INLINE inline
@@ -25,3 +29,13 @@ int64_t get_device_attribute(int64_t attribute, int64_t device_id);
 
 int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
+
+namespace cuda_utils {
+
+template <typename T>
+HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
+ceil_div(T a, T b) {
+  return (a + b - 1) / b;
+}
+
+};  // namespace cuda_utils
\ No newline at end of file
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index e40f2822..53921abc 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -1,7 +1,7 @@
 #include
 
 #include "c3x/scaled_mm_kernels.hpp"
-#include "core/math.hpp"
+#include "cuda_utils.h"
 
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
@@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   auto make_group_shape = [](torch::Tensor const& x,
                              torch::Tensor const& s) -> GroupShape {
     TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
-    return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
+    return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+            cuda_utils::ceil_div(x.size(1), s.size(1))};
   };
 
   GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index ef81db14..d2aecba4 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "convert_fp8(Tensor! 
dst_cache, Tensor src_cache, float scale, " "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); + + // Gather cache blocks from src_cache to dst. + cache_ops.def( + "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " + "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); + cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 21c02c5d..b8b5e204 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -682,8 +682,6 @@ def test_swap_blocks_mla( torch.ops._C_cache_ops.swap_blocks, (src_cache, dst_cache, block_mapping_tensor), test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(kv_lora_rank == KV_LORA_RANKS[0] - and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]), ) ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) @@ -694,3 +692,76 @@ def test_swap_blocks_mla( dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " f"{dst} in dst_cache.") + + +@pytest.mark.parametrize("kv_lora_rank", [512]) +@pytest.mark.parametrize("qk_rope_head_dim", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_blocks", [1024]) +@pytest.mark.parametrize("max_seq_len", [512]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", + ["auto"]) # You can also test "fp8" if needed. +@pytest.mark.parametrize("align_cache", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, + num_blocks, max_seq_len, batch_size, dtype, + kv_cache_dtype, align_cache, device): + entry_size = kv_lora_rank + qk_rope_head_dim + src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, + kv_cache_dtype, device, align_cache) + _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) + + seq_len_tensor = torch.randint(0, + max_seq_len + 1, (batch_size, ), + device=device) + + total_tokens = seq_len_tensor.sum() + cu_seq_lens = torch.empty((batch_size + 1), + dtype=torch.int32, + device=device) + cu_seq_lens[0] = 0 + cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) + print("seq_len_tensor", seq_len_tensor) + + tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size + block_table = torch.empty((batch_size, num_blocks), + dtype=torch.int32, + device=device) + + for b in range(batch_size): + perm = torch.randperm(num_blocks, device=device) + block_table[b, :] = perm + + dst = torch.zeros((total_tokens, entry_size), + dtype=src_cache.dtype, + device=device) + + expected_batches = [] + for b in range(batch_size): + s = seq_len_tensor[b] + if s == 0: + continue + tot = tot_blocks_tensor[b] + blocks = block_table[b, :tot].tolist() + + gathered_rows = [] + for i in range(tot - 1): + gathered_rows.append(src_cache[blocks[i]]) + remaining = s - (tot - 1) * block_size + gathered_rows.append(src_cache[blocks[-1], :remaining, :]) + + batch_expected = torch.cat(gathered_rows, dim=0) + expected_batches.append(batch_expected) + expected = torch.cat(expected_batches, dim=0) + + opcheck( + torch.ops._C_cache_ops.gather_cache, + (src_cache, dst, block_table, cu_seq_lens, batch_size, None), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size) + 
torch.testing.assert_close(dst, expected) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e3e3c644..2112af12 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor, torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) +def gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 85c5715f..89229e7b 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -4,16 +4,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) +from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend __all__ = [ - "Attention", - "AttentionBackend", - "AttentionMetadata", - "AttentionType", - "AttentionMetadataBuilder", - "Attention", - "AttentionState", - "get_attn_backend", + "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", + "AttentionMetadataBuilder", "Attention", "AttentionState", + "get_attn_backend", "get_flash_attn_version" ] diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py new file mode 100644 index 00000000..c3dbbdb8 --- /dev/null +++ b/vllm/attention/backends/mla/common.py @@ -0,0 +1,1503 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 
128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N * P] +W_KR project h_t to k_pe shape [H, N * R] +W_UV project kv_c to v shape [Lkv, N * V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. "_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK).view(Skv, N, P) +v = (kv_c @ W_UV).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatnated per head + `q_b_proj` is [W_UQ; W_QR] concatnated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Ahead of time, compute: + +% this projects from q_c to [Sq, N * Lkv] +W_UQ_UK = einsum("qnp,knp -> qnk" + W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) + ).view(Lkv, N * Lkv) +% this projects from attn output [Sq, N * Lkv] to [Sq, H] +W_UV_O = einsum("knv,nvh -> nkh" + W_UV.view(Lkv, N, V), W_O.view(N, V, H) + ).view(N * Lkv, H) + +Runtime +q_c = h_t @ W_DQ +q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([q_latent, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) +return spda_o.reshape(-1, N * Lkv) @ W_UV_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. + +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. 
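+
+As a rough sketch of the resulting bound: if the shared workspace holds W
+tokens and B prefills in the batch still have context, then
+
+  MCC = round_down(W // B, page_size)
+
+and each iteration gathers at most W context tokens in total, so the
+up-projected k_nope/v chunks stay bounded by roughly W * N * (P + V) elements
+regardless of Skv.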
+ +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) +new_v = (new_kv_c @ W_UV).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import functools +from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar) + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, MLAAttentionImpl) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + get_flash_attn_version, + is_block_tables_empty) +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8Fp8) +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_quantize) +from vllm.model_executor.layers.rotary_embedding import ( + DeepseekScalingRotaryEmbedding, RotaryEmbedding) +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down + +try: + from vllm.vllm_flash_attn import 
flash_attn_varlen_func +except ImportError: + # For rocm use upstream flash attention + from flash_attn import flash_attn_varlen_func + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class MLACommonBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +class MLACommonState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + self.model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + 
use_cuda_graph=True, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + if self.chunked_prefill_enabled: + if not hasattr(self, "chunked_prefill_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + + model_input.attn_metadata.chunked_prefill_workspace = \ + self.chunked_prefill_workspace + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
+ use_cuda_graph: bool + + # New for MLA (compared to FlashAttention) + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["MLACommonMetadata"] = None + _cached_decode_metadata: Optional["MLACommonMetadata"] = None + + num_prefill_tokens: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM + is_profile_run: bool = False + + # New for MLA (compared to FlashAttention) + # For chunked prefill + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None + # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted + chunked_prefill_workspace: Optional[torch.Tensor] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + @property + def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=False, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, + ) + return self._cached_prefill_metadata + + @property + 
def decode_metadata(self) -> Optional["MLACommonMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + + self._cached_decode_metadata = MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + input_positions=input_positions, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
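+            # For example, a batch scheduled as [2 prefills | 3 decodes]
+            # is treated as 5 decode requests from this point on.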
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +T = TypeVar("T", bound=MLACommonMetadata) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + + if self.chunked_prefill_enabled: + attn_state = self.input_builder.runner.attn_state + self.chunked_prefill_workspace_size = \ + attn_state.chunked_prefill_workspace_size + self.page_size = self.runner.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + if self.chunked_prefill_enabled and self.num_prefills > 0 \ + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.chunked_prefill_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, self.num_prefills)\ + * max_context_chunk + 
chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.chunked_prefill_workspace_size + + return MLACommonMetadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.use_yarn_rope = isinstance(rotary_emb, + DeepseekScalingRotaryEmbedding) + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + self.vllm_flash_attn_version = get_flash_attn_version() + + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from 
vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + def _v_up_proj_and_o_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_UV_O): + output_parallel = apply_fp8_linear_generic( + x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape) + else: + output_parallel = torch.matmul(x.flatten(start_dim=1), + self.W_UV_O) + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + return output + else: + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, + self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_Q_UK): + return apply_fp8_linear_generic( + x, self.W_Q_UK, self.W_Q_UK_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape).view( + -1, self.num_heads, self.kv_lora_rank) + return torch.matmul(x, self.W_Q_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + else: + x = torch.matmul(x, self.W_Q)\ + .view(-1, self.num_heads, self.qk_nope_head_dim) + return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + # TODO(lucas) This is very gross, we need a more wide scale refactor of + # all the FP8 code with a more standard way of + # defining schemes/group-shapes, we should also potentially force + # quant_methods to support a decompress function + # + # returns input_group_shape, weight_group_shape + def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ + Tuple[Tuple[int, int], Tuple[int, int]]: + if isinstance(layer.quant_method, Fp8LinearMethod): + if layer.quant_method.block_quant: + weight_block_size = \ + layer.quant_method.quant_config.weight_block_size + # per-token-group (1, X), block-quantized (X, Y) + return (1, weight_block_size[-1]), weight_block_size + else: + return (-1, -1), (-1, -1) # per-tensor, per-tensor + elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # this is hacky but we always assume the for + # CompressedTensorsW8A8Fp8 the input is dynamic per-token + # we ignore if it is static-per-tensor since we are going to + # requantize after later anyways + strategy = layer.scheme.strategy + if strategy == QuantizationStrategy.TENSOR: + return (1, -1), (-1, -1) # per-token, per-tensor + elif strategy == QuantizationStrategy.CHANNEL: + return (1, -1), (-1, 1) # per-token, per-channel + else: + raise NotImplementedError( + f"QuantizationStrategy.{strategy} is not supported for " + "fp8 MLA, please run with VLLM_MLA_DISABLE=1") + else: + raise NotImplementedError( + "Can't determine scale group shapes for " + f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" + ) + + def get_layer_weight(layer): + if hasattr(layer, "weight"): + return layer.weight + elif hasattr(layer, "qweight"): + return layer.qweight + else: + raise AttributeError( + f"Layer '{layer}' has neither weight nor qweight") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, 
UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.o_proj).dtype == weight_dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype + + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention backend + # perspective though we call these both W_Q and rely on the layer + # to pass in the correct matrix + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() + + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) + + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION + if is_fp8(weight_dtype) and requantization_enabled: + # This assumes it wise to requantize using the same group shapes + # (i.e. strategy, per-tensor, per-channel, block etc.) that the + # weights were originally quantized + requant_input_group_shape, requant_weight_group_shape = \ + get_scale_group_shapes_for_fp8(self.q_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.kv_b_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.o_proj) + self.reqaunt_input_group_shape = requant_input_group_shape + self.reqaunt_weight_group_shape = requant_weight_group_shape + + # + # Perform matrix-absorption following + # https://github.com/flashinfer-ai/flashinfer/pull/551 + # for decode, as a result we end up with absorbed weights for decode + # and another copy of raw weights for prefill. 
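+            # (Concretely, W_UK is folded into the query projection to form
+            #  W_Q_UK / W_UQ_UK, and W_UV is folded into the output projection
+            #  to form W_UV_O; see the module docstring above.)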
+ # + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK + # depending q_lora_rank, the former if q_lora_rank is None, the + # latter otherwise + # basically if q_lora_rank is none we are absorbing into q_proj + # instead of UQ + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + .flatten(start_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_Q_UK, W_Q_UK_scales = scaled_quantize( + W_Q_UK, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_Q_UK = W_Q_UK.T.contiguous() + self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() + else: + self.W_Q_UK = W_Q_UK.to(act_dtype) + + W_O = get_and_maybe_dequant_weights(self.o_proj)\ + .view(-1, self.num_heads, self.v_head_dim) + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + .flatten(start_dim=0, end_dim=1).contiguous() + + if is_fp8(weight_dtype) and requantization_enabled: + W_UV_O, W_UV_O_scales = scaled_quantize( + W_UV_O, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_UV_O = W_UV_O.T.contiguous() + self.W_UV_O_scales = W_UV_O_scales.T.contiguous() + else: + self.W_UV_O = W_UV_O.to(act_dtype) + + self.tp_size = get_tensor_model_parallel_world_size() + else: + if is_fp8(weight_dtype): + raise NotImplementedError( + "Currently fp8 requires matrix absorption") + + self.W_UV = W_UV + self.W_UK = W_UK + self.W_Q = W_Q.flatten(start_dim=1) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.chunked_prefill_workspace is not None + workspace = attn_metadata.chunked_prefill_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank].unsqueeze(1) + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad + # out v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, + [0, q.shape[-1] 
- v.shape[-1]], + value=0) + + attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) + + if has_context: + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # slice by `:v.shape[-1]` in order to remove v headdim padding + output = output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(output)[0] + + @abstractmethod + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + if attn_metadata.is_profile_run and \ + attn_metadata.chunked_prefill_workspace is not None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in 
`_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.chunked_prefill_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + + decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:] + decode_k_pe = k_pe[num_prefill_tokens:] + decode_input_positions = \ + attn_metadata.input_positions[num_prefill_tokens:] + + prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_input_positions = \ + attn_metadata.input_positions[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + + if has_decode: + decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) + decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( + decode_input_positions, decode_q_pe, decode_k_pe) + + if has_prefill: + prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] + prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( + prefill_input_positions, prefill_q_pe, prefill_k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.o_proj.output_size, + device=hidden_states_or_q_c.device, + dtype=hidden_states_or_q_c.dtype) + if has_prefill: + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + output[num_prefill_tokens:] = self._forward_decode( + decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + + return output diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py deleted file mode 100644 index df3fb2ae..00000000 --- a/vllm/attention/backends/mla/utils.py +++ /dev/null @@ -1,515 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import functools -from abc import abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, Tuple - -import torch -from compressed_tensors.quantization import QuantizationStrategy - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, - AttentionMetadata, - MLAAttentionImpl, T) -from vllm.attention.backends.utils import get_flash_attn_version -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, RowParallelLinear, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8Fp8) -from 
vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_quantize) -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, RotaryEmbedding) - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func -except ImportError: - from flash_attn import flash_attn_varlen_func - - -@dataclass -class MLACommonMetadata(AttentionMetadata): - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - - -class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): - """ - Common class for implementing repeated parts - - Main reference: DeepseekV2 paper, and FlashInfer Implementation - (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - - Deepseek's MLA attention works the following way: - * Use a single latent vector to represent the entire KV cache. - * The attention "simulates" a multi-head attention, while the compute is - similar to multi-query attention. - * The dataflow is as follows, - - * B: batch/sequence length - * H: hidden size - * N: number of attention heads - * Lq: latent dimension for Q - * Lkv: latent dimension for K/V - * P: nope dimension, P+R is the actual head_dim in common attention. - * R: rope dimension, this slide of the head_dim goes through rope. - * V: V head dim. - * kv_c: latent/compressed KV - * q_c: latent/compressed Q - - # - # Outside the MLA attention backend - # - - 1. The hidden states (B, H) are projected down into cq (B, Lq) and - kv_c_k_pe (B, Lkv+R). - 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq - and kv_c are normalized. - - # - # Inside the MLA attention backend - # - - * if prefill: - - 3. The q_c is then projected up into the multi-head version. - * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope - (B, N, P) and q_pe (B, N, R). - 4. q_pe, k_pe are then passed through rotary embeddings. - 5. kv_c and k_pe are concatenated and inserted into the cache - 6. The kv_c is then projected up into the multi-head version. - * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope - dimensions for K and V, which is split into k_nope (B, N, P) - and v (B, N, V). - 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from - q_nope, q_pe, k_nope, k_pe. - 8. Attention is computued with q, k, v. - 9. The attention computation returns (B, N, V), which is projected back - to (B, H) using out projection. - - * if decode: - - 3. Here's the change, we do not perform up the full up projection for - q_c, and there is no up projection at all for kv_c. This is - achieved by the technique of "weight absorption". The paper says - "Fortunately, due to the associative law of matrix multiplication, - we can absorb WUK into WUQ, and WUV into WO" - * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it - into W_UQ (Lq, N, P) and W_QR (Lq, N, R). - * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split - it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). - * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). - * We can precompute the product of W_UQ and W_UK into - W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in - attention. 
- * We can precompute the product of W_UV and W_O into - W_UV_O (N, Lkv, H), which is possible due to V@O as the - "epilogue" of attention - 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. - 5. q_pe, k_pe are then passed through rotary embeddings. - 6. kv_c and k_pe are concatenated and inserted into the cache - 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape - (B, N, Lkv). - 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, - kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. - 9. The attention is computed with q, k, v. Note that we just performed - a MQA attention with (LKv+R) as our head dim. - 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), - which included the v and rope values. - 11. The attention computation returns (B, N, Lkv), which is projected - back to (B, H) using W_UV_O. - - From @tsu-bin's calculation, we only want to use the absorption technique - for decode. The prefill algorithm should still use the up-projected MHA - for less flops and memory usage. - - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - # MLA Specific Arguments - q_lora_rank: Optional[int], - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - rotary_emb: RotaryEmbedding, - # q_proj should be q_b_proj if q_lora_rank is not None, but from an - # attention backend perspective we rely on the layer to pass in the - # correct matrix - q_proj: ColumnParallelLinear, - kv_b_proj: ColumnParallelLinear, - o_proj: RowParallelLinear, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - - self.rotary_emb = rotary_emb - self.use_yarn_rope = isinstance(rotary_emb, - DeepseekScalingRotaryEmbedding) - self.q_proj = q_proj - self.kv_b_proj = kv_b_proj - self.o_proj = o_proj - self.vllm_flash_attn_version = get_flash_attn_version() - - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. 
The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - def _v_up_proj_and_o_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_UV_O): - output_parallel = apply_fp8_linear_generic( - x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape) - else: - output_parallel = torch.matmul(x.flatten(start_dim=1), - self.W_UV_O) - if self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - return output - else: - x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) - return self.o_proj(x.reshape(-1, - self.num_heads * self.v_head_dim))[0] - - def _q_proj_and_k_up_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_Q_UK): - return apply_fp8_linear_generic( - x, self.W_Q_UK, self.W_Q_UK_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape).view( - -1, self.num_heads, self.kv_lora_rank) - return torch.matmul(x, self.W_Q_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - else: - x = torch.matmul(x, self.W_Q)\ - .view(-1, self.num_heads, self.qk_nope_head_dim) - return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - # TODO(lucas) This is very gross, we need a more wide scale refactor of - # all the FP8 code with a more standard way of - # defining schemes/group-shapes, we should also potentially force - # quant_methods to support a decompress function - # - # returns input_group_shape, weight_group_shape - def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ - Tuple[Tuple[int, int], Tuple[int, int]]: - if isinstance(layer.quant_method, Fp8LinearMethod): - if layer.quant_method.block_quant: - weight_block_size = \ - layer.quant_method.quant_config.weight_block_size - # per-token-group (1, X), block-quantized (X, Y) - return (1, weight_block_size[-1]), weight_block_size - else: - return (-1, -1), (-1, -1) # per-tensor, per-tensor - elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ - and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): - # this is hacky but we always assume the for - # CompressedTensorsW8A8Fp8 the input is dynamic per-token - # we ignore if it is static-per-tensor since we are going to - # requantize after later anyways - strategy = layer.scheme.strategy - if strategy == QuantizationStrategy.TENSOR: - return (1, -1), (-1, -1) # per-token, per-tensor - elif strategy == QuantizationStrategy.CHANNEL: - return (1, -1), (-1, 1) # per-token, per-channel - else: - raise NotImplementedError( - f"QuantizationStrategy.{strategy} is not supported for " - "fp8 MLA, please run with VLLM_MLA_DISABLE=1") - else: - raise NotImplementedError( - "Can't determine scale group shapes for " - f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" - ) - - def get_layer_weight(layer): - if hasattr(layer, "weight"): - return layer.weight - elif hasattr(layer, "qweight"): - return layer.qweight - else: - raise AttributeError( - f"Layer '{layer}' has neither weight nor qweight") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: 
This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - weight_dtype = get_layer_weight(self.kv_b_proj).dtype - assert get_layer_weight(self.o_proj).dtype == weight_dtype - assert get_layer_weight(self.q_proj).dtype == weight_dtype - - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ - .view(-1, self.num_heads, self.qk_head_dim) - - # can be W_Q or W_UQ depending q_lora_rank, the former if - # q_lora_rank is None, the latter otherwise. From the Attention backend - # perspective though we call these both W_Q and rely on the layer - # to pass in the correct matrix - W_Q = q_proj_weight[..., :self.qk_nope_head_dim] - self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ - .flatten(start_dim=1).contiguous() - - # W_QR is small so for simplicity we dont bother requantizing it - self.W_QR = self.W_QR.to(act_dtype) - - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION - if is_fp8(weight_dtype) and requantization_enabled: - # This assumes it wise to requantize using the same group shapes - # (i.e. strategy, per-tensor, per-channel, block etc.) that the - # weights were originally quantized - requant_input_group_shape, requant_weight_group_shape = \ - get_scale_group_shapes_for_fp8(self.q_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.kv_b_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.o_proj) - self.reqaunt_input_group_shape = requant_input_group_shape - self.reqaunt_weight_group_shape = requant_weight_group_shape - - # - # Perform matrix-absorption following - # https://github.com/flashinfer-ai/flashinfer/pull/551 - # for decode, as a result we end up with absorbed weights for decode - # and another copy of raw weights for prefill. 
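The requantization branch above reuses the group shapes the original checkpoint was quantized with (per-token-group inputs, block-quantized weights). As a rough, hypothetical illustration of what a (128, 128) block-quantized weight looks like, with one scale per tile, here is a plain-PyTorch sketch; it is not the `scaled_quantize` helper used in this diff and assumes a PyTorch build that provides `torch.float8_e4m3fn` and weights whose shapes divide evenly by the block size:

    import torch

    def block_quant_fp8_sketch(w: torch.Tensor, block=(128, 128)):
        # One scale per (block[0] x block[1]) tile of the weight matrix.
        fp8 = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8).max
        bm, bn = block
        M, N = w.shape
        assert M % bm == 0 and N % bn == 0, "sketch assumes evenly divisible shapes"
        tiles = w.view(M // bm, bm, N // bn, bn)
        amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
        scales = amax / fp8_max
        q = (tiles / scales).clamp(-fp8_max, fp8_max).to(fp8)
        # Return the quantized weight and a (M//bm, N//bn) grid of tile scales.
        return q.view(M, N), scales.squeeze(1).squeeze(-1)

    w_q, w_s = block_quant_fp8_sketch(torch.randn(256, 512))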
- # - self.W_UK, self.W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK - # depending q_lora_rank, the former if q_lora_rank is None, the - # latter otherwise - # basically if q_lora_rank is none we are absorbing into q_proj - # instead of UQ - W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ - .flatten(start_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_Q_UK, W_Q_UK_scales = scaled_quantize( - W_Q_UK, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_Q_UK = W_Q_UK.T.contiguous() - self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() - else: - self.W_Q_UK = W_Q_UK.to(act_dtype) - - W_O = get_and_maybe_dequant_weights(self.o_proj)\ - .view(-1, self.num_heads, self.v_head_dim) - W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ - .flatten(start_dim=0, end_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_UV_O, W_UV_O_scales = scaled_quantize( - W_UV_O, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_UV_O = W_UV_O.T.contiguous() - self.W_UV_O_scales = W_UV_O_scales.T.contiguous() - else: - self.W_UV_O = W_UV_O.to(act_dtype) - - self.tp_size = get_tensor_model_parallel_world_size() - else: - if is_fp8(weight_dtype): - raise NotImplementedError( - "Currently fp8 requires matrix absorption") - - self.W_UV = W_UV - self.W_UK = W_UK - self.W_Q = W_Q.flatten(start_dim=1) - - @abstractmethod - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - hidden_states_or_q_c: torch.Tensor, # query in unified attn - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: T, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if output is not None: - raise NotImplementedError( - "output is not yet supported for MLAImplBase") - - is_decode = attn_metadata.decode_metadata is not None - is_prefill = attn_metadata.prefill_metadata is not None - - if (is_decode and is_prefill): - raise NotImplementedError( - "chunked prefill is not supported for MLAImplBase") - - # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") - - if is_decode: - q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) - q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ - .view(-1, self.num_heads, self.qk_rope_head_dim) - q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, - k_pe) - else: - assert is_prefill - q = self.q_proj(hidden_states_or_q_c)[0]\ - .view(-1, self.num_heads, self.qk_head_dim) - - # TODO(lucas): there must be a nicer way to write this line - q[..., self.qk_nope_head_dim:], k_pe = \ - self.rotary_emb( - attn_metadata.input_positions, - q[..., self.qk_nope_head_dim:], k_pe) - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - 
k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - if attn_metadata.prefill_metadata is not None: - return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) - - if attn_metadata.decode_metadata is not None: - return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) - - # Optional common flash-attn based prefill - def _forward_prefill_flash( - self, - q: torch.Tensor, - k_c_normed: torch.Tensor, - k_pe: torch.Tensor, - seq_start_loc: torch.Tensor, - max_prefill_seq_len: int, - ) -> torch.Tensor: - - kv_nope = self.kv_b_proj(k_c_normed)[0]\ - .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=seq_start_loc, - cu_seqlens_k=seq_start_loc, - max_seqlen_q=max_prefill_seq_len, - max_seqlen_k=max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - ) - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(attn_output)[0] diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 9a1984a9..08e8226a 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,40 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type - -from vllm.multimodal import MultiModalPlaceholderMap - -try: - from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -except ImportError: - BatchDecodeMlaWithPagedKVCacheWrapper = None - FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +from typing import Any, Dict, List, Optional, Type import torch -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd -from vllm.utils import async_tensor_h2d, make_tensor_with_pad - -if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) -class TritonMLABackend(AttentionBackend): +class TritonMLABackend(MLACommonBackend): @staticmethod def get_name() -> str: @@ -44,610 +21,8 @@ class TritonMLABackend(AttentionBackend): def get_impl_cls() -> Type["TritonMLAImpl"]: return TritonMLAImpl - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return TritonMLAMetadata - @staticmethod - def 
get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: - return TritonMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["TritonMLAState"]: - return TritonMLAState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -class TritonMLAState(AttentionState): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self._graph_block_tables[:batch_size], - use_cuda_graph=True, - input_positions=self._positions[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - "input_positions": attn_metadata.decode_metadata.input_positions, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_positions = attn_metadata.input_positions - num_positions = input_positions.shape[0] - input_buffers["seq_lens_tensor"].copy_( - 
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - # CUDA graph buffer is padded so only perform a partial copy based on - # num_positions - input_buffers["input_positions"][:num_positions].copy_( - input_positions, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - return - - -@dataclass -class TritonMLAMetadata(MLACommonMetadata): - """Metadata for TritonMLAMetadata. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
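As a tiny illustration of the cumulative-length convention documented above (the same layout is used for query_start_loc, seq_start_loc, and the cu_seq_lens passed to gather_cache), assuming two sequences of lengths 4 and 6:

    from itertools import accumulate

    import torch

    seq_lens = [4, 6]
    cu_seq_lens = list(accumulate(seq_lens, initial=0))   # [0, 4, 10]

    # All tokens are packed into one flat dimension; sequence i occupies
    # rows cu_seq_lens[i]:cu_seq_lens[i + 1].
    tokens = torch.arange(sum(seq_lens))
    per_seq = [tokens[cu_seq_lens[i]:cu_seq_lens[i + 1]]
               for i in range(len(seq_lens))]
    assert [t.numel() for t in per_seq] == seq_lens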
- seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None - _cached_decode_metadata: Optional["TritonMLAMetadata"] = None - - num_prefill_tokens: int - - num_kv_splits: int = 4 # TODO(lucas) add heuristic - attn_logits: Optional[torch.Tensor] = None - req_idx: Optional[torch.Tensor] = None - - # The dimension of the attention heads - head_dim: Optional[int] = None - - def __post_init__(self): - supported_head_sizes = TritonMLABackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") - - @property - def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - input_positions = (None if self.input_positions is None else - self.input_positions[:self.num_prefill_tokens]) - - self._cached_prefill_metadata = TritonMLAMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - input_positions=input_positions, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - head_dim=self.head_dim) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["TritonMLAMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - input_positions = (None if self.input_positions is None else - self.input_positions[self.num_prefill_tokens:]) - - self._cached_decode_metadata = TritonMLAMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - input_positions=input_positions, - head_dim=self.head_dim) - return self._cached_decode_metadata - - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Mutli-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. 
- assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - - -class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): - - def __init__(self, input_builder: "ModelInputForGPUBuilder"): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.input_positions: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool, prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. 
- """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) - self.context_lens.append(context_len) - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - input_positions = async_tensor_h2d(self.input_positions, torch.long, - device, self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return TritonMLAMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - input_positions=input_positions, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - num_kv_splits=4, # TODO(lucas) add heuristic - head_dim=self.runner.model_config.get_head_size(), - ) - - -class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( self, @@ -662,11 +37,11 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments - **kwargs) -> None: + **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **kwargs) + **mla_args) unsupported_features = [ alibi_slopes, sliding_window, 
blocksparse_params, logits_soft_cap @@ -683,24 +58,12 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): "are not implemented for " "TritonMLAImpl") - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - attn_metadata: TritonMLAMetadata, - ) -> torch.Tensor: - assert isinstance(attn_metadata, TritonMLAMetadata) - return self._forward_prefill_flash(q, kv_c_normed, k_pe, - attn_metadata.seq_start_loc, - attn_metadata.max_prefill_seq_len) - def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: TritonMLAMetadata, + attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 if self.kv_cache_dtype.startswith("fp8"): @@ -717,12 +80,14 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): dtype=q.dtype, device=q.device) + num_kv_splits = 4 # TODO: heuristic + # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( ( B, self.num_heads, - attn_metadata.num_kv_splits, + num_kv_splits, # NOTE(lucas) idk why the +1 is here but sglang has it so we # just mirror that self.kv_lora_rank + 1, @@ -740,7 +105,6 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_meta.block_tables, decode_meta.seq_lens_tensor, attn_logits, - attn_metadata.num_kv_splits, self.scale, - PAGE_SIZE) + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py new file mode 100644 index 00000000..31545b60 --- /dev/null +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch +import triton +import triton.language as tl + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
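The Triton launch below implements the log-sum-exp merge of two partial attention results from section 2.2 of the linked paper; the chunked-prefill context loop in `_compute_prefill_context` relies on it to fold per-chunk outputs together. A plain-PyTorch reference of the same math (a testing sketch, not the vLLM kernel; shapes follow the comments in the kernel signature below, outputs [T, H, D] and LSEs [H, T]) could look like this:

    import torch

    def merge_attn_states_ref(prefix_output, prefix_lse, suffix_output, suffix_lse):
        # Subtract the running max before exponentiating, mirroring the kernel's
        # numerical-stability note: compute the scales first, then apply them.
        max_lse = torch.maximum(prefix_lse, suffix_lse)             # [H, T]
        p = torch.exp(prefix_lse - max_lse)
        s = torch.exp(suffix_lse - max_lse)
        denom = p + s
        merged_lse = torch.log(denom) + max_lse                     # [H, T]
        p_scale = (p / denom).transpose(0, 1).unsqueeze(-1)         # [T, H, 1]
        s_scale = (s / denom).transpose(0, 1).unsqueeze(-1)
        merged = prefix_output * p_scale + suffix_output * s_scale  # [T, H, D]
        return merged, merged_lse

Because the merge is associative, it can be chained across any number of partial results, which is why the context loop above can process cache chunks one at a time.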
+ merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + out_se = (tl.exp(p_lse) + tl.exp(s_lse)) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. + p_scale = tl.exp(p_lse) / out_se + s_scale = tl.exp(s_lse) / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/vllm/config.py b/vllm/config.py index f118004b..d6e197fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3332,19 +3332,6 @@ class VllmConfig: current_platform.check_and_update_config(self) - # If MLA is enabled, force disable chunked prefill and prefix caching - if self.model_config and self.model_config.use_mla: - logger.info("MLA is enabled; forcing chunked prefill and prefix " - "caching to be disabled.") - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.chunked_prefill_enabled = False - self.scheduler_config.max_num_batched_tokens = max( - self.scheduler_config.max_model_len, - _DEFAULT_MAX_NUM_BATCHED_TOKENS) - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False - if not self.instance_id: self.instance_id = random_uuid()[:5] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5aa77a13..8b460b33 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1170,9 +1170,9 @@ class EngineArgs: # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. 
- # For multimodal models, chunked prefill is disabled by default in - # V0, but enabled by design in V1 - if model_config.is_multimodal_model: + # For multimodal models and models with MLA, chunked prefill is + # disabled by default in V0, but enabled by design in V1 + if model_config.is_multimodal_model or model_config.use_mla: self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) elif use_long_context: @@ -1207,7 +1207,6 @@ class EngineArgs: msg = "Chunked prefill is not supported for pooling models" raise ValueError(msg) - speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9895537c..891edf23 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -162,6 +162,9 @@ def _per_token_group_quant_fp8( y_q_ptr, y_s_ptr, group_size, + # Num columns of y + y_num_columns, + y_row_stride, # Avoid to divide zero eps, # Information for float8 @@ -174,9 +177,14 @@ def _per_token_group_quant_fp8( quantization on a tensor. This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size y_s_ptr += g_id @@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor( group_size, # Num columns of y y_num_columns, + y_row_stride, # Stride from one column to the next of y_s y_s_col_stride, # Avoid to divide zero @@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor( quantization on a tensor. This function converts the tensor values into float8 values. """ + groups_per_row = y_num_columns // group_size + # Map the program id to the row of X and Y it should compute. 
g_id = tl.program_id(0) - y_ptr += g_id * group_size + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) y_q_ptr += g_id * group_size # Convert g_id the flattened block coordinate to 2D so we can index @@ -267,7 +281,7 @@ def per_token_group_quant_fp8( assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") - assert x.is_contiguous(), "`x` must be contiguous" + assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) fp8_min = finfo.min @@ -295,6 +309,7 @@ def per_token_group_quant_fp8( x_s, group_size, x.shape[1], + x.stride(0), x_s.stride(1), eps, fp8_min=fp8_min, @@ -309,6 +324,8 @@ def per_token_group_quant_fp8( x_q, x_s, group_size, + x.shape[1], + x.stride(0), eps, fp8_min=fp8_min, fp8_max=fp8_max, diff --git a/vllm/utils.py b/vllm/utils.py index b1bac649..4d3f90c9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y +def round_down(x: int, y: int) -> int: + return (x // y) * y + + def _generate_random_fp8( tensor: torch.Tensor, low: float, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b1b5cc35..1922a3bf 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -5,12 +5,11 @@ from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np import torch -import triton -import triton.language as tl from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import get_flash_attn_version +from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv @@ -372,70 +371,4 @@ def cascade_attention( # Merge prefix and suffix outputs, and store the result in output. merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) - - -def merge_attn_states( - output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, -) -> None: - num_tokens = output.shape[0] - num_query_heads = output.shape[1] - head_size = output.shape[2] - padded_head_size = triton.next_power_of_2(head_size) - - # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. 
- merge_attn_states_kernel[(num_tokens, num_query_heads)]( - output, - prefix_output, - prefix_lse, - suffix_output, - suffix_lse, - head_size, - padded_head_size, - ) - - -@triton.jit -def merge_attn_states_kernel( - output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse, # [NUM_HEADS, NUM_TOKENS] - suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse, # [NUM_HEADS, NUM_TOKENS] - HEAD_SIZE: tl.constexpr, - PADDED_HEAD_SIZE: tl.constexpr, -): - token_idx = tl.program_id(0) - num_tokens = tl.num_programs(0) - head_idx = tl.program_id(1) - num_heads = tl.num_programs(1) - - p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) - s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) - max_lse = tl.maximum(p_lse, s_lse) - p_lse = p_lse - max_lse - s_lse = s_lse - max_lse - - head_arange = tl.arange(0, PADDED_HEAD_SIZE) - head_mask = head_arange < HEAD_SIZE - p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - - # NOTE(woosuk): Be careful with the numerical stability. - # We should compute the scale first, and then multiply it with the output. - # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. - p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) - s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse)) - out = p_out * p_scale + s_out * s_scale - tl.store(output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - out, - mask=head_mask) + suffix_lse) \ No newline at end of file
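For reference, the log-sum-exp merge that the new merge_attn_states Triton kernel performs (and that the removed helper in flash_attn.py used to inline) can be written in plain PyTorch as below. This is a minimal sketch, not part of the patch; the function name merge_attn_states_ref is illustrative, and the tensor shapes follow the kernel's comments ([NUM_TOKENS, NUM_HEADS, HEAD_SIZE] for outputs, [NUM_HEADS, NUM_TOKENS] for LSEs).

# Reference sketch (not part of this patch) of the LSE-based merge used to
# combine partial attention results in the split-KV / cascade case.
import torch

def merge_attn_states_ref(prefix_output: torch.Tensor,   # [T, H, D]
                          prefix_lse: torch.Tensor,      # [H, T]
                          suffix_output: torch.Tensor,   # [T, H, D]
                          suffix_lse: torch.Tensor):     # [H, T]
    # Subtract the per-(head, token) max before exponentiating, exactly as
    # the kernel does for numerical stability.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p = torch.exp(prefix_lse - max_lse)                   # [H, T]
    s = torch.exp(suffix_lse - max_lse)                   # [H, T]
    out_se = p + s
    # Compute the scales first and then apply them to the outputs, mirroring
    # the kernel's NOTE about numerical stability.
    p_scale = (p / out_se).transpose(0, 1).unsqueeze(-1)  # [T, H, 1]
    s_scale = (s / out_se).transpose(0, 1).unsqueeze(-1)  # [T, H, 1]
    output = prefix_output * p_scale + suffix_output * s_scale
    # Merged LSE, as returned when output_lse is requested.
    output_lse = torch.log(out_se) + max_lse              # [H, T]
    return output, output_lse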
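The _per_token_group_quant_fp8 change relaxes the old requirement that x be fully contiguous to the weaker check x.stride(-1) == 1, and maps each flat Triton program id to a (row, group-within-row) pair using the row stride. A small sketch of that index math is shown below; it is illustrative only (the helper name group_offset is not from the patch), under the assumption that y_num_columns is divisible by group_size, as the surrounding assert enforces.

# Illustrative sketch (not part of the patch) of the flat-group-id -> offset
# mapping used by the updated _per_token_group_quant_fp8 kernels.
def group_offset(g_id: int, group_size: int, y_num_columns: int,
                 y_row_stride: int) -> int:
    # Each row of y holds y_num_columns // group_size quantization groups.
    groups_per_row = y_num_columns // group_size
    row = g_id // groups_per_row      # which row this program handles
    row_g_id = g_id % groups_per_row  # which group within that row
    # Rows may be strided (e.g. y is a row-contiguous slice of a larger
    # tensor), so advance by y_row_stride per row instead of assuming the
    # row stride equals y_num_columns.
    return row * y_row_stride + row_g_id * group_size

When y is fully contiguous, y_row_stride == y_num_columns and the expression reduces to the previous g_id * group_size offset, so the change is a strict generalization.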