[Attention] Update to latest FA3 code (#13111)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Lucas Wilkinson 2025-04-17 18:14:07 -04:00 committed by GitHub
parent 3408e47159
commit 183dad7a85
5 changed files with 241 additions and 118 deletions

View File

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
+        GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.triton_fa_func = triton_attention

         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)

+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                         return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+            and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        elif is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # Use return_attn_probs instead of return_softmax_lse for RoCM
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there is multiple results,
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse = True`
+        # flash_attn (RoCM) returns (output, softmax_lse, ...) when
+        # `return_attn_probs = True`
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
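
The `_flash_attn_varlen_diff_headdims` helper added above is what lets the call sites later in this file stop padding V themselves: MLA's qk head dim (nope + rope) is wider than its v head dim, so on backends that require equal head dims the helper zero-pads V up to the qk width before the kernel call and slices the padding off the output afterwards (and skips both steps on Hopper with FA3). The following standalone torch sketch, with made-up shapes and a stand-in for the kernel call, illustrates just that trick; it is not code from this commit.

import torch

# Toy MLA-like shapes: qk head dim (nope + rope) is wider than the v head dim.
num_tokens, num_heads, qk_head_dim, v_head_dim = 8, 4, 192, 128
q = torch.randn(num_tokens, num_heads, qk_head_dim)
v = torch.randn(num_tokens, num_heads, v_head_dim)

# Pad the last dim of v with zeros up to the qk head dim.
maybe_padded_v = torch.nn.functional.pad(
    v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert maybe_padded_v.shape[-1] == qk_head_dim

# The attention kernel would consume (q, k, maybe_padded_v) here; the zero
# columns of v can only ever contribute zeros to the output columns.
attn_out = maybe_padded_v.clone()  # stand-in for the kernel output

# Slice the padded columns back off so callers see the true v head dim.
attn_out = attn_out[..., :v.shape[-1]]
assert attn_out.shape == v.shape
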
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)

-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )

             if output is None:
                 output = attn_output
@@ -1252,58 +1295,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            ## triton flash attention always return 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_softmax_lse=has_context,
-            )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
+        output = self._flash_attn_varlen_diff_headdims(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=prefill_metadata.query_start_loc,
+            cu_seqlens_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax_lse=has_context,
+        )

         if has_context:
-            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 suffix_lse=suffix_lse,
             )

-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]

     @abstractmethod
     def _forward_decode(

View File

@@ -2,8 +2,10 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)

 import numpy as np
 import torch
@@ -11,6 +13,7 @@ import torch
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+@dataclass
+class MLADims:
+    q_lora_rank: Optional[int]
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
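
A hedged usage sketch of the new helper (not part of this commit): `get_mla_dims` only reads attribute names off `model_config.hf_text_config`, so a stand-in namespace with DeepSeek-style MLA fields is enough to exercise it. The numeric values below are illustrative, not taken from this change.

from types import SimpleNamespace

from vllm.attention.backends.utils import get_mla_dims

# Stand-in for a vllm.config.ModelConfig; only .hf_text_config is accessed.
fake_model_config = SimpleNamespace(hf_text_config=SimpleNamespace(
    q_lora_rank=1536,      # illustrative DeepSeek-style values
    kv_lora_rank=512,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
))

dims = get_mla_dims(fake_model_config)  # type: ignore[arg-type]
# The compressed KV entry MLA caches is kv_lora_rank + qk_rope_head_dim wide.
print(dims.kv_lora_rank + dims.qk_rope_head_dim)  # 576
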

View File

@@ -23,7 +23,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                      get_scheduler_metadata)

 logger = init_logger(__name__)
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]

+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
 class FlashAttentionMetadataBuilder:

     def __init__(self, runner: "GPUModelRunner"):
+        model_config = runner.model_config
+
         self.runner = runner
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.headdim = model_config.get_head_size()
+        self.page_size = self.runner.block_size

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
         )

         use_cascade = common_prefix_len > 0

+        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
+                     causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=num_reqs,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads,
+                    num_heads_kv=self.num_heads,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
                                                 common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
+            prefix_scheduler_metadata = schedule(
+                cu_query_lens=cu_prefix_query_lens,
+                max_query_len=num_actual_tokens,
+                seqlens=prefix_kv_lens,
+                max_seq_len=common_prefix_len,
+                causal=False)
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=suffix_kv_lens,
+                                          max_seq_len=max_seq_len -
+                                          common_prefix_len,
+                                          causal=True)
         else:
             cu_prefix_query_lens = None
             prefix_kv_lens = None
             suffix_kv_lens = None
+            prefix_scheduler_metadata = None
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=seq_lens,
+                                          max_seq_len=max_seq_len,
+                                          causal=True)

         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
+            scheduler_metadata=scheduler_metadata,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
         )

         return attn_metadata
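
To summarize the data flow introduced above without the surrounding builder machinery: when FA3 is available, scheduling metadata is planned once per batch while `FlashAttentionMetadata` is built (separately for the non-causal cascade prefix and the causal suffix), stored on the metadata object, and later handed to the kernel calls unchanged. The toy stand-in below (not the vLLM or FA3 API; the plan is faked as a dict) mirrors just that control flow.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyAttnMetadata:
    # Stand-ins for the two new FlashAttentionMetadata fields.
    scheduler_metadata: Optional[dict] = None
    prefix_scheduler_metadata: Optional[dict] = None

def build_toy_metadata(aot_schedule: bool, use_cascade: bool) -> ToyAttnMetadata:
    def schedule(causal: bool) -> Optional[dict]:
        # Mirrors `schedule()` above: only plan when AOT scheduling is on.
        if aot_schedule:
            return {"causal": causal}  # placeholder for the real plan tensor
        return None

    if use_cascade:
        # Cascade: plan the shared prefix (non-causal) and the suffix (causal).
        return ToyAttnMetadata(
            scheduler_metadata=schedule(causal=True),
            prefix_scheduler_metadata=schedule(causal=False))
    # No cascade: a single causal plan over the whole batch.
    return ToyAttnMetadata(scheduler_metadata=schedule(causal=True))

meta = build_toy_metadata(aot_schedule=True, use_cascade=True)
# The forward pass then just forwards meta.scheduler_metadata (and, for
# cascade, meta.prefix_scheduler_metadata) into the attention calls.
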
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
+                scheduler_metadata=attn_metadata.scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
             fa_version=self.vllm_flash_attn_version,
+            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
+            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
             q_descale=layer._q_scale,
             k_descale=layer._k_scale,
             v_descale=layer._v_scale,
@@ -636,6 +689,8 @@ def cascade_attention(
     block_table: torch.Tensor,
     common_prefix_len: int,
     fa_version: int,
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
+    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
@@ -667,6 +722,7 @@ def cascade_attention(
         block_table=block_table[:1],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=prefix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
@@ -693,6 +749,7 @@ def cascade_attention(
         block_table=block_table[:, num_common_kv_blocks:],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=suffix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
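
Both the cascade path above and the MLA chunked-prefill path merge two partial attention results computed over disjoint key ranges using their per-query log-sum-exp values (vLLM's `merge_attn_states` does this on-device). The minimal torch sketch below shows the underlying identity with made-up shapes and no masking, for readers who want to sanity-check the math; it is not code from this commit.

import torch

torch.manual_seed(0)
num_q, num_k1, num_k2, d = 3, 5, 7, 16
q = torch.randn(num_q, d)
k = torch.randn(num_k1 + num_k2, d)
v = torch.randn(num_k1 + num_k2, d)
scale = d ** -0.5

def attn_with_lse(q, k, v):
    s = (q @ k.T) * scale                 # (num_q, num_k) scaled scores
    lse = torch.logsumexp(s, dim=-1)      # per-query log-sum-exp
    out = torch.softmax(s, dim=-1) @ v    # (num_q, d) attention output
    return out, lse

# Two partial results over disjoint key ranges (e.g. shared prefix vs suffix,
# or already-computed context chunks vs the current chunk).
out1, lse1 = attn_with_lse(q, k[:num_k1], v[:num_k1])
out2, lse2 = attn_with_lse(q, k[num_k1:], v[num_k1:])

# Merge: reweight each partial output by its share of the total softmax mass.
lse = torch.logaddexp(lse1, lse2)
merged = (torch.exp(lse1 - lse)[:, None] * out1
          + torch.exp(lse2 - lse)[:, None] * out2)

full, _ = attn_with_lse(q, k, v)
assert torch.allclose(merged, full, atol=1e-5)
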

View File

@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
         model_config = runner.model_config
         cache_config = runner.cache_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.mla_dims = get_mla_dims(model_config)
+        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+
+        # Dont try to access the runner on AMD
+        if self.aot_schedule:
+            self.page_size = self.runner.block_size

         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
                 dtype=model_config.dtype,
                 device=runner.device,
             )
-        self.page_size = self.runner.block_size

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         seq_lens = seq_lens_cpu.to(device, non_blocking=True)
-        max_query_len = seq_lens_cpu.max().item()

         prefill_metadata = None
         if self._num_prefills > 0:
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]

             chunked_context_metadata = None
             if self.chunked_prefill_enabled and self._num_prefills > 0 \
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
             prefill_metadata = MLACommonPrefillMetadata(
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
-                query_start_loc=query_start_loc[reqs_start:] -
-                query_start_loc[reqs_start],
+                query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
             )
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
         self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)

+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        attn_out = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=maybe_padded_v,
+            return_softmax_lse=return_softmax_lse,
+            softmax_scale=softmax_scale,
+            **kwargs,
+        )
+
+        # Unpack the output if there is multiple results
+        lse = None
+        if isinstance(attn_out, tuple):
+            attn_out, lse = attn_out[0], attn_out[1]
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            return attn_out, lse
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)

-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
                 q=q,
                 k=k,
-                v=v_padded,
+                v=v,
                 cu_seqlens_q=prefill_metadata.query_start_loc,
                 cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
                 max_seqlen_q=prefill_metadata.max_query_len,
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        output = self.flash_attn_varlen_func(
+        output = self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
-            v=v_padded,
+            v=v,
             cu_seqlens_q=attn_metadata.prefill.query_start_loc,
             cu_seqlens_k=attn_metadata.prefill.query_start_loc,
             max_seqlen_q=attn_metadata.prefill.max_query_len,
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 suffix_lse=suffix_lse,
             )

-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]

     @abstractmethod
     def _forward_decode(
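
The final change in both MLA backends replaces the old view/slice/reshape chain with `output.flatten(start_dim=-2)`, which works because `_flash_attn_varlen_diff_headdims` has already removed any v head-dim padding before the output reaches `o_proj`. The small torch check below demonstrates that equivalence with toy shapes; it is not code from this commit.

import torch

num_tokens, num_heads, qk_head_dim, v_head_dim = 8, 4, 192, 128
padded_out = torch.randn(num_tokens, num_heads, qk_head_dim)

# Old path: the attention output still carried the padded qk head dim, so the
# padding was sliced off and the heads flattened in one chained expression.
old = (padded_out
       .view(-1, num_heads, qk_head_dim)[..., :v_head_dim]
       .reshape(-1, num_heads * v_head_dim))

# New path: the helper already returns the unpadded
# (num_tokens, num_heads, v_head_dim) tensor, so only the flatten remains.
unpadded_out = padded_out[..., :v_head_dim]
new = unpadded_out.flatten(start_dim=-2)

assert torch.equal(old, new)
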