[Attention] MLA with chunked prefill (#12639)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Patrick Horn <patrick.horn@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
900edbfa48
commit
288cc6c234
@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
void gather_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
@ -2,6 +2,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// grid is launched with dimensions (batch, num_splits)
|
||||
template <typename scalar_t>
|
||||
__global__ void gather_cache(
|
||||
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
||||
// ENTRIES...]
|
||||
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
|
||||
const int32_t block_size, const int32_t entry_size,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
|
||||
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
|
||||
// batch
|
||||
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = cu_seq_lens[bid];
|
||||
const int32_t seq_end = cu_seq_lens[bid + 1];
|
||||
const int32_t seq_len = seq_end - seq_start;
|
||||
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
|
||||
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
|
||||
|
||||
const int32_t split_start = split * split_blocks;
|
||||
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
|
||||
|
||||
const bool is_active_split = (split_start < tot_blocks);
|
||||
const bool is_last_split = (split_end == tot_blocks);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
int32_t full_blocks_end = split_end;
|
||||
int32_t partial_block_size = 0;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch.
|
||||
// If seq_starts is provided, compute an offset based on (seq_starts[bid] /
|
||||
// page_size)
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = 0;
|
||||
if (seq_starts != nullptr) {
|
||||
offset = seq_starts[bid] / block_size;
|
||||
}
|
||||
const int32_t* batch_block_table = block_table + batch_offset + offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths.
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
if (is_last_split) {
|
||||
partial_block_size = seq_len % block_size;
|
||||
if (partial_block_size) full_blocks_end -= 1;
|
||||
}
|
||||
|
||||
auto copy_entry = [&](const scalar_t* __restrict__ _src,
|
||||
scalar_t* __restrict__ _dst) {
|
||||
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
|
||||
_dst[i] = _src[i];
|
||||
};
|
||||
|
||||
for (int pid = split_start; pid < full_blocks_end; ++pid) {
|
||||
auto block_id = batch_block_table[pid];
|
||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
|
||||
for (int eid = 0; eid < block_size; ++eid) {
|
||||
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
||||
block_dst_ptr + eid * dst_entry_stride);
|
||||
}
|
||||
}
|
||||
|
||||
if (partial_block_size) {
|
||||
auto block_id = batch_block_table[full_blocks_end];
|
||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
|
||||
for (int eid = 0; eid < partial_block_size; ++eid) {
|
||||
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
||||
block_dst_ptr + eid * dst_entry_stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_GATHER_CACHE(CPY_DTYPE) \
|
||||
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
|
||||
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
|
||||
block_size, entry_size, block_table_stride, cache_block_stride, \
|
||||
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
|
||||
|
||||
// Gather sequences from the cache into the destination tensor.
|
||||
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
||||
// - block_table contains the cache block indices for each sequence
|
||||
// - Optionally, seq_starts (if provided) offsets the starting block index by
|
||||
// (seq_starts[bid] / page_size)
|
||||
void gather_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size,
|
||||
std::optional<torch::Tensor> seq_starts = std::nullopt) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int32_t block_size = src_cache.size(1);
|
||||
int32_t entry_size = src_cache.flatten(2, -1).size(2);
|
||||
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||
"block_table must be int32");
|
||||
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
|
||||
"cu_seq_lens must be int32");
|
||||
if (seq_starts.has_value()) {
|
||||
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
|
||||
"seq_starts must be int32");
|
||||
}
|
||||
|
||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||
"src_cache and dst must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||
"src_cache and block_table must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
|
||||
"src_cache and cu_seq_lens must be on the same device");
|
||||
if (seq_starts.has_value()) {
|
||||
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
|
||||
"src_cache and seq_starts must be on the same device");
|
||||
}
|
||||
|
||||
int64_t block_table_stride = block_table.stride(0);
|
||||
int64_t cache_block_stride = src_cache.stride(0);
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
// Decide on the number of splits based on the batch size.
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(1024);
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
|
||||
"src_cache and dst must have the same dtype");
|
||||
|
||||
const int dtype_bits = src_cache.element_size() * 8;
|
||||
const int32_t* seq_starts_ptr =
|
||||
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
||||
|
||||
if (dtype_bits == 32) {
|
||||
CALL_GATHER_CACHE(uint32_t);
|
||||
} else if (dtype_bits == 16) {
|
||||
CALL_GATHER_CACHE(uint16_t);
|
||||
} else if (dtype_bits == 8) {
|
||||
CALL_GATHER_CACHE(uint8_t);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
@ -2,10 +2,14 @@
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
|
||||
#define DEVICE_INLINE __forceinline__ __device__
|
||||
#define HOST_INLINE __forceinline__ __host__
|
||||
#if defined(__HIPCC__)
|
||||
#define HOST_DEVICE_INLINE __host__ __device__
|
||||
#define DEVICE_INLINE __device__
|
||||
#define HOST_INLINE __host__
|
||||
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
|
||||
#define DEVICE_INLINE __device__ __forceinline__
|
||||
#define HOST_INLINE __host__ __forceinline__
|
||||
#else
|
||||
#define HOST_DEVICE_INLINE inline
|
||||
#define DEVICE_INLINE inline
|
||||
@ -25,3 +29,13 @@
|
||||
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
||||
|
||||
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
||||
|
||||
namespace cuda_utils {
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
|
||||
ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
}; // namespace cuda_utils
|
@ -1,7 +1,7 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
|
@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
"str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
|
||||
|
||||
// Gather cache blocks from src_cache to dst.
|
||||
cache_ops.def(
|
||||
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
@ -682,8 +682,6 @@ def test_swap_blocks_mla(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_cache, dst_cache, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(kv_lora_rank == KV_LORA_RANKS[0]
|
||||
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
|
||||
)
|
||||
|
||||
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
|
||||
@ -694,3 +692,76 @@ def test_swap_blocks_mla(
|
||||
dst_cache[dst].cpu(),
|
||||
msg=f"Block {src} from src should have been swapped to block "
|
||||
f"{dst} in dst_cache.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", [64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("num_blocks", [1024])
|
||||
@pytest.mark.parametrize("max_seq_len", [512])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype",
|
||||
["auto"]) # You can also test "fp8" if needed.
|
||||
@pytest.mark.parametrize("align_cache", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||
num_blocks, max_seq_len, batch_size, dtype,
|
||||
kv_cache_dtype, align_cache, device):
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device, align_cache)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0,
|
||||
max_seq_len + 1, (batch_size, ),
|
||||
device=device)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
block_table = torch.empty((batch_size, num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
for b in range(batch_size):
|
||||
perm = torch.randperm(num_blocks, device=device)
|
||||
block_table[b, :] = perm
|
||||
|
||||
dst = torch.zeros((total_tokens, entry_size),
|
||||
dtype=src_cache.dtype,
|
||||
device=device)
|
||||
|
||||
expected_batches = []
|
||||
for b in range(batch_size):
|
||||
s = seq_len_tensor[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = tot_blocks_tensor[b]
|
||||
blocks = block_table[b, :tot].tolist()
|
||||
|
||||
gathered_rows = []
|
||||
for i in range(tot - 1):
|
||||
gathered_rows.append(src_cache[blocks[i]])
|
||||
remaining = s - (tot - 1) * block_size
|
||||
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
|
||||
|
||||
batch_expected = torch.cat(gathered_rows, dim=0)
|
||||
expected_batches.append(batch_expected)
|
||||
expected = torch.cat(expected_batches, dim=0)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.gather_cache,
|
||||
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor,
|
||||
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
|
||||
|
||||
|
||||
def gather_cache(src_cache: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cu_seq_lens: torch.Tensor,
|
||||
batch_size: int,
|
||||
seq_starts: Optional[torch.Tensor] = None) -> None:
|
||||
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
|
||||
cu_seq_lens, batch_size, seq_starts)
|
||||
|
||||
|
||||
def get_device_attribute(attribute: int, device: int) -> int:
|
||||
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
||||
|
||||
|
@ -4,16 +4,12 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
__all__ = [
|
||||
"Attention",
|
||||
"AttentionBackend",
|
||||
"AttentionMetadata",
|
||||
"AttentionType",
|
||||
"AttentionMetadataBuilder",
|
||||
"Attention",
|
||||
"AttentionState",
|
||||
"get_attn_backend",
|
||||
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
|
||||
"AttentionMetadataBuilder", "Attention", "AttentionState",
|
||||
"get_attn_backend", "get_flash_attn_version"
|
||||
]
|
||||
|
1503
vllm/attention/backends/mla/common.py
Normal file
1503
vllm/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,515 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl, T)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW8A8Fp8)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_quantize)
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata(AttentionMetadata):
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
"""
|
||||
Common class for implementing repeated parts
|
||||
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the entire KV cache.
|
||||
* The attention "simulates" a multi-head attention, while the compute is
|
||||
similar to multi-query attention.
|
||||
* The dataflow is as follows,
|
||||
|
||||
* B: batch/sequence length
|
||||
* H: hidden size
|
||||
* N: number of attention heads
|
||||
* Lq: latent dimension for Q
|
||||
* Lkv: latent dimension for K/V
|
||||
* P: nope dimension, P+R is the actual head_dim in common attention.
|
||||
* R: rope dimension, this slide of the head_dim goes through rope.
|
||||
* V: V head dim.
|
||||
* kv_c: latent/compressed KV
|
||||
* q_c: latent/compressed Q
|
||||
|
||||
#
|
||||
# Outside the MLA attention backend
|
||||
#
|
||||
|
||||
1. The hidden states (B, H) are projected down into cq (B, Lq) and
|
||||
kv_c_k_pe (B, Lkv+R).
|
||||
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
|
||||
and kv_c are normalized.
|
||||
|
||||
#
|
||||
# Inside the MLA attention backend
|
||||
#
|
||||
|
||||
* if prefill:
|
||||
|
||||
3. The q_c is then projected up into the multi-head version.
|
||||
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
|
||||
(B, N, P) and q_pe (B, N, R).
|
||||
4. q_pe, k_pe are then passed through rotary embeddings.
|
||||
5. kv_c and k_pe are concatenated and inserted into the cache
|
||||
6. The kv_c is then projected up into the multi-head version.
|
||||
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
|
||||
dimensions for K and V, which is split into k_nope (B, N, P)
|
||||
and v (B, N, V).
|
||||
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
|
||||
q_nope, q_pe, k_nope, k_pe.
|
||||
8. Attention is computued with q, k, v.
|
||||
9. The attention computation returns (B, N, V), which is projected back
|
||||
to (B, H) using out projection.
|
||||
|
||||
* if decode:
|
||||
|
||||
3. Here's the change, we do not perform up the full up projection for
|
||||
q_c, and there is no up projection at all for kv_c. This is
|
||||
achieved by the technique of "weight absorption". The paper says
|
||||
"Fortunately, due to the associative law of matrix multiplication,
|
||||
we can absorb WUK into WUQ, and WUV into WO"
|
||||
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
|
||||
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
|
||||
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
|
||||
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
|
||||
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
|
||||
* We can precompute the product of W_UQ and W_UK into
|
||||
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
|
||||
attention.
|
||||
* We can precompute the product of W_UV and W_O into
|
||||
W_UV_O (N, Lkv, H), which is possible due to V@O as the
|
||||
"epilogue" of attention
|
||||
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
|
||||
5. q_pe, k_pe are then passed through rotary embeddings.
|
||||
6. kv_c and k_pe are concatenated and inserted into the cache
|
||||
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
|
||||
(B, N, Lkv).
|
||||
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
|
||||
kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
|
||||
9. The attention is computed with q, k, v. Note that we just performed
|
||||
a MQA attention with (LKv+R) as our head dim.
|
||||
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
|
||||
which included the v and rope values.
|
||||
11. The attention computation returns (B, N, Lkv), which is projected
|
||||
back to (B, H) using W_UV_O.
|
||||
|
||||
From @tsu-bin's calculation, we only want to use the absorption technique
|
||||
for decode. The prefill algorithm should still use the up-projected MHA
|
||||
for less flops and memory usage.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: ColumnParallelLinear,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.use_yarn_rope = isinstance(rotary_emb,
|
||||
DeepseekScalingRotaryEmbedding)
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
|
||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||
# and the one from vllm_flash_attn. The former is used on RoCM and the
|
||||
# latter has an additional parameter to control FA2 vs FA3
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
if self.vllm_flash_attn_version is not None:
|
||||
self.flash_attn_varlen_func = \
|
||||
functools.partial(flash_attn_varlen_func,
|
||||
fa_version=self.vllm_flash_attn_version)
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_UV_O):
|
||||
output_parallel = apply_fp8_linear_generic(
|
||||
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape)
|
||||
else:
|
||||
output_parallel = torch.matmul(x.flatten(start_dim=1),
|
||||
self.W_UV_O)
|
||||
if self.tp_size > 1:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
return output
|
||||
else:
|
||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
return self.o_proj(x.reshape(-1,
|
||||
self.num_heads * self.v_head_dim))[0]
|
||||
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
if is_fp8(self.W_Q_UK):
|
||||
return apply_fp8_linear_generic(
|
||||
x, self.W_Q_UK, self.W_Q_UK_scales,
|
||||
self.reqaunt_input_group_shape,
|
||||
self.reqaunt_weight_group_shape).view(
|
||||
-1, self.num_heads, self.kv_lora_rank)
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
x = torch.matmul(x, self.W_Q)\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim)
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# TODO(lucas) This is very gross, we need a more wide scale refactor of
|
||||
# all the FP8 code with a more standard way of
|
||||
# defining schemes/group-shapes, we should also potentially force
|
||||
# quant_methods to support a decompress function
|
||||
#
|
||||
# returns input_group_shape, weight_group_shape
|
||||
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
|
||||
Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
if isinstance(layer.quant_method, Fp8LinearMethod):
|
||||
if layer.quant_method.block_quant:
|
||||
weight_block_size = \
|
||||
layer.quant_method.quant_config.weight_block_size
|
||||
# per-token-group (1, X), block-quantized (X, Y)
|
||||
return (1, weight_block_size[-1]), weight_block_size
|
||||
else:
|
||||
return (-1, -1), (-1, -1) # per-tensor, per-tensor
|
||||
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
|
||||
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
|
||||
# this is hacky but we always assume the for
|
||||
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
|
||||
# we ignore if it is static-per-tensor since we are going to
|
||||
# requantize after later anyways
|
||||
strategy = layer.scheme.strategy
|
||||
if strategy == QuantizationStrategy.TENSOR:
|
||||
return (1, -1), (-1, -1) # per-token, per-tensor
|
||||
elif strategy == QuantizationStrategy.CHANNEL:
|
||||
return (1, -1), (-1, 1) # per-token, per-channel
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"QuantizationStrategy.{strategy} is not supported for "
|
||||
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Can't determine scale group shapes for "
|
||||
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
|
||||
)
|
||||
|
||||
def get_layer_weight(layer):
|
||||
if hasattr(layer, "weight"):
|
||||
return layer.weight
|
||||
elif hasattr(layer, "qweight"):
|
||||
return layer.qweight
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has neither weight nor qweight")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
weight_dtype = get_layer_weight(self.kv_b_proj).dtype
|
||||
assert get_layer_weight(self.o_proj).dtype == weight_dtype
|
||||
assert get_layer_weight(self.q_proj).dtype == weight_dtype
|
||||
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
# W_QR is small so for simplicity we dont bother requantizing it
|
||||
self.W_QR = self.W_QR.to(act_dtype)
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
# This assumes it wise to requantize using the same group shapes
|
||||
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
|
||||
# weights were originally quantized
|
||||
requant_input_group_shape, requant_weight_group_shape = \
|
||||
get_scale_group_shapes_for_fp8(self.q_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
|
||||
assert (requant_input_group_shape, requant_weight_group_shape)\
|
||||
== get_scale_group_shapes_for_fp8(self.o_proj)
|
||||
self.reqaunt_input_group_shape = requant_input_group_shape
|
||||
self.reqaunt_weight_group_shape = requant_weight_group_shape
|
||||
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically if q_lora_rank is none we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_Q_UK, W_Q_UK_scales = scaled_quantize(
|
||||
W_Q_UK,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_Q_UK = W_Q_UK.T.contiguous()
|
||||
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
|
||||
else:
|
||||
self.W_Q_UK = W_Q_UK.to(act_dtype)
|
||||
|
||||
W_O = get_and_maybe_dequant_weights(self.o_proj)\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
if is_fp8(weight_dtype) and requantization_enabled:
|
||||
W_UV_O, W_UV_O_scales = scaled_quantize(
|
||||
W_UV_O,
|
||||
self.reqaunt_weight_group_shape,
|
||||
quant_dtype=current_platform_fp8_dtype)
|
||||
# For FP8 save the transpose so we can use
|
||||
# `apply_w8a8_block_fp8_linear` directly
|
||||
self.W_UV_O = W_UV_O.T.contiguous()
|
||||
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
|
||||
else:
|
||||
self.W_UV_O = W_UV_O.to(act_dtype)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
else:
|
||||
if is_fp8(weight_dtype):
|
||||
raise NotImplementedError(
|
||||
"Currently fp8 requires matrix absorption")
|
||||
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
is_decode = attn_metadata.decode_metadata is not None
|
||||
is_prefill = attn_metadata.prefill_metadata is not None
|
||||
|
||||
if (is_decode and is_prefill):
|
||||
raise NotImplementedError(
|
||||
"chunked prefill is not supported for MLAImplBase")
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
assert hasattr(attn_metadata, "input_positions")
|
||||
|
||||
if is_decode:
|
||||
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
|
||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
|
||||
k_pe)
|
||||
else:
|
||||
assert is_prefill
|
||||
q = self.q_proj(hidden_states_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# TODO(lucas): there must be a nicer way to write this line
|
||||
q[..., self.qk_nope_head_dim:], k_pe = \
|
||||
self.rotary_emb(
|
||||
attn_metadata.input_positions,
|
||||
q[..., self.qk_nope_head_dim:], k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if attn_metadata.prefill_metadata is not None:
|
||||
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
|
||||
|
||||
if attn_metadata.decode_metadata is not None:
|
||||
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
|
||||
|
||||
# Optional common flash-attn based prefill
|
||||
def _forward_prefill_flash(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
seq_start_loc: torch.Tensor,
|
||||
max_prefill_seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
|
||||
kv_nope = self.kv_b_proj(k_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
|
||||
attn_output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=seq_start_loc,
|
||||
cu_seqlens_k=seq_start_loc,
|
||||
max_seqlen_q=max_prefill_seq_len,
|
||||
max_seqlen_k=max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output\
|
||||
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
.reshape(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_output)[0]
|
@ -1,40 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
|
||||
class TritonMLABackend(AttentionBackend):
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -44,610 +21,8 @@ class TritonMLABackend(AttentionBackend):
|
||||
def get_impl_cls() -> Type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return TritonMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
|
||||
return TritonMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["TritonMLAState"]:
|
||||
return TritonMLAState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
ops.copy_blocks_mla(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
class TritonMLAState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
self._positions = torch.zeros((max_batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
del self._positions
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
max_query_len=1,
|
||||
max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
input_positions=self._positions[:batch_size],
|
||||
head_dim=self.runner.model_config.get_head_size())
|
||||
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
"input_positions": attn_metadata.decode_metadata.input_positions,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_positions = attn_metadata.input_positions
|
||||
num_positions = input_positions.shape[0]
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
# CUDA graph buffer is padded so only perform a partial copy based on
|
||||
# num_positions
|
||||
input_buffers["input_positions"][:num_positions].copy_(
|
||||
input_positions, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonMLAMetadata(MLACommonMetadata):
|
||||
"""Metadata for TritonMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["TritonMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = TritonMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
input_positions=input_positions,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = TritonMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch sized, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
assert self.max_query_len == 1
|
||||
assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
self.max_decode_seq_len = max(self.seq_lens)
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block, input_positions) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
inter_data.input_positions):
|
||||
self.input_positions.extend(input_positions)
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
# It may be possible to have more blocks allocated due
|
||||
# to lookahead slots of multi-step, however, they are
|
||||
# not used anyway, so can be safely ignored.
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
input_positions = async_tensor_h2d(self.input_positions, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return TritonMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
input_positions=input_positions,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
num_kv_splits=4, # TODO(lucas) add heuristic
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -662,11 +37,11 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**kwargs)
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
@ -683,24 +58,12 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: TritonMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(attn_metadata, TritonMLAMetadata)
|
||||
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.max_prefill_seq_len)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: TritonMLAMetadata,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
@ -717,12 +80,14 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
num_kv_splits = 4 # TODO: heuristic
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
B,
|
||||
self.num_heads,
|
||||
attn_metadata.num_kv_splits,
|
||||
num_kv_splits,
|
||||
# NOTE(lucas) idk why the +1 is here but sglang has it so we
|
||||
# just mirror that
|
||||
self.kv_lora_rank + 1,
|
||||
@ -740,7 +105,6 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor, attn_logits,
|
||||
attn_metadata.num_kv_splits, self.scale,
|
||||
PAGE_SIZE)
|
||||
num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
84
vllm/attention/ops/triton_merge_attn_states.py
Normal file
84
vllm/attention/ops/triton_merge_attn_states.py
Normal file
@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_attn_states_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
output_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
out_se = (tl.exp(p_lse) + tl.exp(s_lse))
|
||||
|
||||
if OUTPUT_LSE:
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
|
||||
# NOTE(woosuk): Be careful with the numerical stability.
|
||||
# We should compute the scale first, and then multiply it with the output.
|
||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||
p_scale = tl.exp(p_lse) / out_se
|
||||
s_scale = tl.exp(s_lse) / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
out,
|
||||
mask=head_mask)
|
@ -3332,19 +3332,6 @@ class VllmConfig:
|
||||
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# If MLA is enabled, force disable chunked prefill and prefix caching
|
||||
if self.model_config and self.model_config.use_mla:
|
||||
logger.info("MLA is enabled; forcing chunked prefill and prefix "
|
||||
"caching to be disabled.")
|
||||
self.scheduler_config.enable_chunked_prefill = False
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
self.scheduler_config.max_num_batched_tokens = max(
|
||||
self.scheduler_config.max_model_len,
|
||||
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
|
@ -1170,9 +1170,9 @@ class EngineArgs:
|
||||
# long context (> 32K) models. This is to avoid OOM errors in the
|
||||
# initial memory profiling phase.
|
||||
|
||||
# For multimodal models, chunked prefill is disabled by default in
|
||||
# V0, but enabled by design in V1
|
||||
if model_config.is_multimodal_model:
|
||||
# For multimodal models and models with MLA, chunked prefill is
|
||||
# disabled by default in V0, but enabled by design in V1
|
||||
if model_config.is_multimodal_model or model_config.use_mla:
|
||||
self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)
|
||||
|
||||
elif use_long_context:
|
||||
@ -1207,7 +1207,6 @@ class EngineArgs:
|
||||
msg = "Chunked prefill is not supported for pooling models"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
speculative_config = SpeculativeConfig.maybe_create_spec_config(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=parallel_config,
|
||||
|
@ -162,6 +162,9 @@ def _per_token_group_quant_fp8(
|
||||
y_q_ptr,
|
||||
y_s_ptr,
|
||||
group_size,
|
||||
# Num columns of y
|
||||
y_num_columns,
|
||||
y_row_stride,
|
||||
# Avoid to divide zero
|
||||
eps,
|
||||
# Information for float8
|
||||
@ -174,9 +177,14 @@ def _per_token_group_quant_fp8(
|
||||
quantization on a tensor.
|
||||
This function converts the tensor values into float8 values.
|
||||
"""
|
||||
groups_per_row = y_num_columns // group_size
|
||||
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
g_id = tl.program_id(0)
|
||||
y_ptr += g_id * group_size
|
||||
row = g_id // groups_per_row
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
|
||||
y_q_ptr += g_id * group_size
|
||||
y_s_ptr += g_id
|
||||
|
||||
@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
group_size,
|
||||
# Num columns of y
|
||||
y_num_columns,
|
||||
y_row_stride,
|
||||
# Stride from one column to the next of y_s
|
||||
y_s_col_stride,
|
||||
# Avoid to divide zero
|
||||
@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
quantization on a tensor.
|
||||
This function converts the tensor values into float8 values.
|
||||
"""
|
||||
groups_per_row = y_num_columns // group_size
|
||||
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
g_id = tl.program_id(0)
|
||||
y_ptr += g_id * group_size
|
||||
row = g_id // groups_per_row
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
|
||||
y_q_ptr += g_id * group_size
|
||||
|
||||
# Convert g_id the flattened block coordinate to 2D so we can index
|
||||
@ -267,7 +281,7 @@ def per_token_group_quant_fp8(
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
f"by `group_size` {group_size}")
|
||||
assert x.is_contiguous(), "`x` must be contiguous"
|
||||
assert x.stride(-1) == 1, "`x` groups must be contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
@ -295,6 +309,7 @@ def per_token_group_quant_fp8(
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
x_s.stride(1),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
@ -309,6 +324,8 @@ def per_token_group_quant_fp8(
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
|
@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
|
||||
|
||||
def round_down(x: int, y: int) -> int:
|
||||
return (x // y) * y
|
||||
|
||||
|
||||
def _generate_random_fp8(
|
||||
tensor: torch.Tensor,
|
||||
low: float,
|
||||
|
@ -5,12 +5,11 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
@ -372,70 +371,4 @@ def cascade_attention(
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
|
||||
suffix_lse)
|
||||
|
||||
|
||||
def merge_attn_states(
|
||||
output: torch.Tensor,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_attn_states_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
mask=head_mask)
|
||||
|
||||
# NOTE(woosuk): Be careful with the numerical stability.
|
||||
# We should compute the scale first, and then multiply it with the output.
|
||||
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
|
||||
p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
|
||||
s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(output + token_idx * num_heads * HEAD_SIZE +
|
||||
head_idx * HEAD_SIZE + head_arange,
|
||||
out,
|
||||
mask=head_mask)
|
||||
suffix_lse)
|
Loading…
x
Reference in New Issue
Block a user