[Attention] Update to latest FA3 code (#13111)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent 3408e47159
commit 183dad7a85
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
           vllm-flash-attn
           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
+          GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.triton_fa_func = triton_attention
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                         return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+            and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        elif is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # Use return_attn_probs instead of return_softmax_lse for RoCM
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there is multiple results,
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse = True`
+        # flash_attn (RoCM) returns (output, softmax_lse, ...) when
+        # `return_attn_probs = True`
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
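The `_flash_attn_varlen_diff_headdims` helper added above centralizes a trick the call sites below used to repeat inline: when the value head dim is smaller than the q/k head dim (as in MLA), zero-pad v up to q's head dim, run attention, then slice the padding back off the output. A minimal, self-contained sketch of the same idea using plain PyTorch SDPA (`toy_attn_diff_headdims` and the shapes are illustrative assumptions, not vLLM code):

import torch
import torch.nn.functional as F

def toy_attn_diff_headdims(q, k, v):
    # q, k: (batch, heads, seq, d_qk); v: (batch, heads, seq, d_v), d_v <= d_qk
    d_qk, d_v = q.shape[-1], v.shape[-1]
    # Zero-pad v on its last dim so the backend sees matching head dims.
    v_padded = F.pad(v, [0, d_qk - d_v], value=0)
    out = F.scaled_dot_product_attention(q, k, v_padded)
    # Zero columns in v can only produce zero columns in the output, so
    # slicing them off recovers attention with the original v exactly.
    return out[..., :d_v]

q, k = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 6)
print(toy_attn_diff_headdims(q, k, v).shape)  # torch.Size([1, 2, 4, 6])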
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
                     q=q,
                     k=k,
-                    v=v_padded,
+                    v=v,
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
                     softmax_scale=self.scale,
                     causal=False,  # Context is unmasked
                     return_softmax_lse=True,
                 )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
 
             if output is None:
                 output = attn_output
@@ -1252,33 +1295,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            ## triton flash attention always return 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
+        output = self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
-            v=v_padded,
+            v=v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
@@ -1287,23 +1307,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             causal=True,
             return_softmax_lse=has_context,
         )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
 
         if has_context:
-            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
 
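`_compute_prefill_context` returns the attention over the already-cached context chunks together with its softmax log-sum-exp, and combining it with the suffix output relies on the standard LSE-based merge of partial attention results over disjoint key sets (this is what the `merge_attn_states` kernel implements). A plain-PyTorch sketch of that math; `merge_partial_attn` and the tensor layouts in the comments are assumptions for illustration, not the vLLM kernel:

import torch

def merge_partial_attn(o1, lse1, o2, lse2):
    # o1, o2: (tokens, heads, head_dim) attention outputs over disjoint key sets
    # lse1, lse2: (tokens, heads) log-sum-exp of the attention logits per part
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    # Reweighting by each part's share of the total softmax mass reproduces
    # attention computed over the union of the two key sets.
    # Conceptually: merge(context_output, context_lse, suffix_output, suffix_lse).
    return w1 * o1 + w2 * o2, lse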
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(
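Because `_flash_attn_varlen_diff_headdims` already strips the v-padding, the prefill output reaching `o_proj` is `(num_tokens, num_heads, v_head_dim)`, so the old slice-plus-reshape collapses to a flatten of the last two dims. A quick shape check with made-up sizes:

import torch

num_tokens, num_heads, v_head_dim = 5, 16, 128
out = torch.randn(num_tokens, num_heads, v_head_dim)

flat = out.flatten(start_dim=-2)               # (5, 16 * 128), what o_proj expects
ref = out.reshape(-1, num_heads * v_head_dim)  # the old-style reshape
assert flat.shape == (num_tokens, num_heads * v_head_dim)
assert torch.equal(flat, ref)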
@@ -2,8 +2,10 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -11,6 +13,7 @@ import torch
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
 
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+@dataclass
+class MLADims:
+    q_lora_rank: Optional[int]
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
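`get_mla_dims` simply lifts the MLA projection sizes off `hf_text_config` into a typed dataclass so backends stop reaching into the HF config directly. A small self-contained illustration of the resulting object (the dataclass is re-declared locally so the snippet runs on its own, and the numbers are DeepSeek-V2-style values used purely as an example):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MLADims:  # local copy of the dataclass added above, for a runnable example
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int

dims = MLADims(q_lora_rank=1536, kv_lora_rank=512,
               qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)

# The q/k head dim is the nope + rope parts; note that it differs from
# v_head_dim, which is exactly the mismatch the padding logic in this
# commit has to handle.
qk_head_dim = dims.qk_nope_head_dim + dims.qk_rope_head_dim
print(qk_head_dim, dims.v_head_dim)  # 192 128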
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                      get_scheduler_metadata)
 
 logger = init_logger(__name__)
 
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]
 
+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
+        model_config = runner.model_config
+
         self.runner = runner
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.headdim = model_config.get_head_size()
+        self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
         )
 
         use_cascade = common_prefix_len > 0
+
+        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
+                     causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=num_reqs,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads,
+                    num_heads_kv=self.num_heads,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
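The `schedule` closure is a thin guard around FA3's ahead-of-time scheduling: when `self.aot_schedule` is set it asks `get_scheduler_metadata` for a schedule, otherwise it returns `None` and no precomputed schedule is attached. Doing this in the metadata builder means the schedule is computed once per scheduler step and then shared by every attention layer through the metadata object, rather than recomputed per layer. A rough mock of that flow (`StepMetadata`, `build_step_metadata`, and `fake_layer_forward` are invented names for this sketch, not vLLM APIs):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class StepMetadata:
    scheduler_metadata: Optional[Any]  # produced at most once per step

def build_step_metadata(aot_schedule: bool) -> StepMetadata:
    # Stand-in for the builder: compute (or skip) the schedule for the whole batch.
    return StepMetadata("precomputed-schedule" if aot_schedule else None)

def fake_layer_forward(layer_idx: int, meta: StepMetadata) -> None:
    # Every layer sees the same metadata object and just forwards the field
    # into its attention call, mirroring scheduler_metadata=... below.
    print(f"layer {layer_idx}: scheduler_metadata={meta.scheduler_metadata!r}")

meta = build_step_metadata(aot_schedule=True)
for i in range(3):
    fake_layer_forward(i, meta)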
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
                                                 common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
+            prefix_scheduler_metadata = schedule(
+                cu_query_lens=cu_prefix_query_lens,
+                max_query_len=num_actual_tokens,
+                seqlens=prefix_kv_lens,
+                max_seq_len=common_prefix_len,
+                causal=False)
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=suffix_kv_lens,
+                                          max_seq_len=max_seq_len -
+                                          common_prefix_len,
+                                          causal=True)
         else:
             cu_prefix_query_lens = None
             prefix_kv_lens = None
             suffix_kv_lens = None
+            prefix_scheduler_metadata = None
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=seq_lens,
+                                          max_seq_len=max_seq_len,
+                                          causal=True)
 
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
+            scheduler_metadata=scheduler_metadata,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
         return attn_metadata
 
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
+                scheduler_metadata=attn_metadata.scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
             fa_version=self.vllm_flash_attn_version,
+            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
+            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
             q_descale=layer._q_scale,
             k_descale=layer._k_scale,
             v_descale=layer._v_scale,
@@ -636,6 +689,8 @@ def cascade_attention(
     block_table: torch.Tensor,
     common_prefix_len: int,
     fa_version: int,
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
+    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
@@ -667,6 +722,7 @@ def cascade_attention(
         block_table=block_table[:1],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=prefix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
@@ -693,6 +749,7 @@ def cascade_attention(
         block_table=block_table[:, num_common_kv_blocks:],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=suffix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
         model_config = runner.model_config
         cache_config = runner.cache_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.mla_dims = get_mla_dims(model_config)
+        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+
+        # Dont try to access the runner on AMD
+        if self.aot_schedule:
+            self.page_size = self.runner.block_size
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
                 dtype=model_config.dtype,
                 device=runner.device,
             )
-            self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
 
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         seq_lens = seq_lens_cpu.to(device, non_blocking=True)
-        max_query_len = seq_lens_cpu.max().item()
 
         prefill_metadata = None
         if self._num_prefills > 0:
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
             if self.chunked_prefill_enabled and self._num_prefills > 0 \
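`prefill_query_start_loc` re-bases the cumulative query-length array onto the prefill slice of the batch (the MLA builder reorders requests so decodes come first), which is the form the prefill attention calls expect. A quick numeric illustration with made-up lengths:

import torch

# 4 requests after reordering: 2 decodes (1 token each), then 2 prefills (5 and 5 tokens).
query_start_loc = torch.tensor([0, 1, 2, 7, 12])
reqs_start = 2  # index of the first prefill request

prefill_query_start_loc = (query_start_loc[reqs_start:] -
                           query_start_loc[reqs_start])
print(prefill_query_start_loc)  # tensor([ 0,  5, 10])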
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
             prefill_metadata = MLACommonPrefillMetadata(
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
-                query_start_loc=query_start_loc[reqs_start:] -
-                query_start_loc[reqs_start],
+                query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
             )
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
         self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        attn_out = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=maybe_padded_v,
+            return_softmax_lse=return_softmax_lse,
+            softmax_scale=softmax_scale,
+            **kwargs,
+        )
+
+        # Unpack the output if there is multiple results
+        lse = None
+        if isinstance(attn_out, tuple):
+            attn_out, lse = attn_out[0], attn_out[1]
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            return attn_out, lse
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
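As the code comment above says, padding v can be skipped only when FA3 runs on a Hopper part, which is what the `get_device_capability()[0] == 9` test encodes (SM 9.x). The same condition expressed as a standalone helper (`needs_v_padding` is a hypothetical name, shown only to spell the check out in plain PyTorch):

import torch

def needs_v_padding(fa_version) -> bool:
    # Pad v unless we are running FA3 on a Hopper (compute capability 9.x) GPU.
    if fa_version is None or not torch.cuda.is_available():
        return True
    major, _minor = torch.cuda.get_device_capability()
    return not (fa_version == 3 and major == 9)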
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
                     q=q,
                     k=k,
-                    v=v_padded,
+                    v=v,
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        output = self.flash_attn_varlen_func(
+        output = self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
-            v=v_padded,
+            v=v,
             cu_seqlens_q=attn_metadata.prefill.query_start_loc,
             cu_seqlens_k=attn_metadata.prefill.query_start_loc,
             max_seqlen_q=attn_metadata.prefill.max_query_len,
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(