[Attention] MLA support for V1 (#13789)

Signed-off-by: Yang Chen <yangche@fb.com>
Yang Chen 2025-02-27 10:14:17 -08:00 committed by GitHub
parent f1579b229d
commit 58d1b2aa77
10 changed files with 1340 additions and 59 deletions

View File

@ -89,6 +89,7 @@ class Attention(nn.Module):
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@ -158,6 +159,10 @@ class Attention(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
@ -173,17 +178,25 @@ class Attention(nn.Module):
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.empty(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
# processed differently.
if not self.use_mla:
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
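
Editor's note: a minimal sketch of the new output_shape contract in Attention.forward(), using hypothetical tensor sizes. For standard backends the output buffer simply mirrors the query; an MLA-style caller passes the hidden-state shape explicitly because its query (the compressed q_c) is narrower than the attention output.

import torch
from typing import Optional

def allocate_attn_output(query: torch.Tensor,
                         output_shape: Optional[torch.Size] = None) -> torch.Tensor:
    # Fall back to the query shape when the caller does not override it.
    shape = output_shape if output_shape is not None else query.shape
    return torch.empty(shape, dtype=query.dtype, device=query.device)

# Standard attention: output matches the query (8 tokens, 16 heads x 64 dims).
q = torch.randn(8, 16 * 64)
assert allocate_attn_output(q).shape == q.shape

# MLA-style call: the query is the compressed q_c (hypothetical width 1536),
# but the output must span the full hidden size (hypothetical 5120).
q_c = torch.randn(8, 1536)
out = allocate_attn_output(q_c, output_shape=torch.Size((8, 5120)))
assert out.shape == (8, 5120)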

View File

@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
# In the MLA backend, kv_cache includes both k_c and
# k_pe (i.e. the decoupled rotary position embeddings).
# In particular, the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
class DeepseekV2DecoderLayer(nn.Module):
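
Editor's note: a worked example of the head_size comment above, using DeepSeek-V2's published MLA dimensions (kv_lora_rank=512, qk_rope_head_dim=64); other MLA models may use different values.

kv_lora_rank = 512       # width of the compressed latent k_c
qk_rope_head_dim = 64    # width of the decoupled rotary part k_pe
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576  # one 576-wide cache entry stores k_c and k_pe together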

View File

@ -162,8 +162,13 @@ class CudaPlatformBase(Platform):
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if use_mla:
logger.info("Using Triton MLA backend on V1 engine.")
return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
else:
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend")
if use_mla:
if selected_backend == _Backend.FLASHMLA:
from vllm.attention.backends.flashmla import (

View File

@ -35,6 +35,7 @@ class _Backend(enum.Enum):
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto()
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import numpy as np
import torch
@ -14,6 +14,11 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda():
from vllm.vllm_flash_attn import flash_attn_varlen_func
@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@ -85,6 +94,62 @@ class FlashAttentionMetadata:
num_input_tokens: int = 0 # Number of tokens including padding.
class FlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner"):
self.runner = runner
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput"):
pass
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
return attn_metadata
class FlashAttentionImpl(AttentionImpl):
def __init__(
@ -371,4 +436,4 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)
suffix_lse)
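
Editor's note: the metadata builder added above is consumed by GPUModelRunner later in this commit (get_builder_cls() -> __init__(runner) -> reorder_batch() -> build()). A stripped-down sketch of that call order follows; the class and argument values are illustrative only, not the real runner objects.

class DummyMetadataBuilder:
    def __init__(self, runner):
        # The runner is typically passed as weakref.proxy(self) by GPUModelRunner.
        self.runner = runner

    def reorder_batch(self, input_batch, scheduler_output):
        # Optional hook; the FlashAttention builder leaves the batch order untouched.
        pass

    def build(self, num_reqs, num_actual_tokens, max_query_len, common_prefix_len):
        # A real builder returns backend-specific attention metadata here.
        return {"num_actual_tokens": num_actual_tokens}

builder = DummyMetadataBuilder(runner=None)
builder.reorder_batch(input_batch=None, scheduler_output=None)
metadata = builder.build(num_reqs=2, num_actual_tokens=10,
                         max_query_len=8, common_prefix_len=0)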

File diff suppressed because it is too large

View File

@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Type
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
return TritonMLAImpl
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
num_kv_splits,
# NOTE(lucas) It is not clear why the +1 is needed, but sglang has it,
# so we mirror that here.
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
attn_metadata.block_table, attn_metadata.seq_lens,
attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
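
Editor's note: a shape-only sketch of the tensors _forward_decode above operates on, with hypothetical batch and paging sizes and DeepSeek-V2 dimensions; it only checks that the concatenated query width matches the cache entry width.

import torch

B, num_heads = 4, 16                       # hypothetical decode batch and local heads
kv_lora_rank, qk_rope_head_dim = 512, 64   # DeepSeek-V2 MLA dimensions

q_nope = torch.randn(B, num_heads, kv_lora_rank)
q_pe = torch.randn(B, num_heads, qk_rope_head_dim)
q = torch.cat([q_nope, q_pe], dim=-1)      # (B, heads, 576), matching head_size
o = torch.zeros(B, num_heads, kv_lora_rank)  # the output stays in the latent space

# MQA-style cache: a single KV "head" of width 576 per slot.
num_blocks, page_size = 8, 16
kv_c_and_k_pe_cache = torch.randn(num_blocks, page_size,
                                  kv_lora_rank + qk_rope_head_dim)
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)   # add a head dim of 1
kv_c_cache = kv_c_and_k_pe_cache[..., :kv_lora_rank]     # k_c slice used as "values"
assert q.shape[-1] == kv_c_and_k_pe_cache.shape[-1]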

View File

@ -80,7 +80,14 @@ class InputBatch:
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
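
Editor's note: the point of backing num_computed_tokens_cpu with a torch tensor is that the numpy array and the (optionally pinned) CPU tensor share storage, so existing numpy-side updates keep working while the tensor can be sliced and copied to the GPU with non_blocking=True. A minimal sketch:

import torch

pin = torch.cuda.is_available()            # pin_memory requires a CUDA runtime
t = torch.zeros((4, ), device="cpu", dtype=torch.int32, pin_memory=pin)
a = t.numpy()                              # shares storage with t
a[0] = 7                                   # a numpy-side write ...
assert t[0].item() == 7                    # ... is visible through the tensor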
# Block table.
self.block_table = BlockTable(
@ -356,6 +363,61 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
# Swap the full token-id rows through a temporary copy: numpy slicing
# returns views, so a plain tuple swap would overwrite one row before
# it is read.
tmp_token_ids = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp_token_ids
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
self.min_p_cpu[i2], self.min_p_cpu[i1]
g1 = self.generators.get(i1)
g2 = self.generators.get(i2)
if g1 is not None:
self.generators[i2] = g1
if g2 is not None:
self.generators[i1] = g2
t1 = self.min_tokens.get(i1)
t2 = self.min_tokens.get(i2)
if t1 is not None:
self.min_tokens[i2] = t1
if t2 is not None:
self.min_tokens[i1] = t2
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
self.logit_bias[i1], self.logit_bias[i2] =\
self.logit_bias[i2], self.logit_bias[i1]
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: List[int]) -> None:
num_reqs = self.num_reqs
if num_reqs == 0:
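
Editor's note: swap_states() exists so that attention backends can reorder the persistent batch from their reorder_batch() hook (see the gpu_model_runner.py changes below). One plausible use is sketched here as an illustration only; the attribute names on input_batch and scheduler_output are assumptions, not the MLA backend's actual implementation. The idea: pack single-token decode requests at the front so the memory-bound decode path and the compute-bound prefill path can be handled separately.

def reorder_decodes_first(input_batch, scheduler_output):
    # Treat requests scheduled for exactly one token as decodes (illustrative).
    num_reqs = input_batch.num_reqs
    is_decode = [
        scheduler_output.num_scheduled_tokens[req_id] == 1
        for req_id in input_batch.req_ids[:num_reqs]
    ]
    left, right = 0, num_reqs - 1
    while left < right:
        if is_decode[left]:
            left += 1                 # already in the decode prefix
        elif not is_decode[right]:
            right -= 1                # already in the prefill suffix
        else:
            input_batch.swap_states(left, right)
            is_decode[left], is_decode[right] = True, False
            left += 1
            right -= 1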

View File

@ -2,6 +2,7 @@
import gc
import time
import weakref
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
@ -9,7 +10,7 @@ import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group, graph_capture
@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Some attention backends (namely MLA) may want to separate requests
# based on whether the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs)
@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = self.seq_lens_np[:num_reqs].max()
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
non_blocking=True)
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True).long()
# Prepare for cascade attention if needed.
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
)
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
[0, total_num_scheduled_tokens],
dtype=torch.int32,
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.device)
suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
use_spec_decode = len(
@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
logits_indices = attn_metadata.query_start_loc[1:] - 1
# Hot-Swap lora model
if self.lora_config:
@ -667,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = FlashAttentionBackend.use_cascade_attention(
use_cascade = self.attn_backend.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
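
Editor's note: these last hunks replace hard-coded FlashAttentionBackend calls with the dynamically selected self.attn_backend, because cascade-attention support and the KV-cache layout are backend-specific. As a point of reference, a sketch mirroring the V1 FlashAttention layout used earlier in this commit; MLA backends define a different shape through their own get_kv_cache_shape().

def flash_attn_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
    # Separate K and V planes, each laid out as (blocks, block_size, kv heads, dim).
    return (2, num_blocks, block_size, num_kv_heads, head_size)

assert flash_attn_kv_cache_shape(1024, 16, 8, 128) == (2, 1024, 16, 8, 128)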