[Attention] Update to latest FA3 code (#13111)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

parent 3408e47159
commit 183dad7a85
@@ -38,7 +38,7 @@ else()
         FetchContent_Declare(
                 vllm-flash-attn
                 GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-                GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
+                GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
                 GIT_PROGRESS TRUE
                 # Don't share the vllm-flash-attn build between build types
                 BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.triton_fa_func = triton_attention
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                         return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        if is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # Use return_attn_probs instead of return_softmax_lse for RoCM
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there is multiple results,
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse = True`
+        # flash_attn (RoCM) returns (output, softmax_lse, ...) when
+        # `return_attn_probs = True`
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
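
Note: the new `_flash_attn_varlen_diff_headdims` helper centralizes the V-padding MLA needs when the attention kernel cannot handle a value head dim smaller than the query/key head dim, and strips the padding off again afterwards. A minimal standalone sketch of that pad-then-slice idea (the shapes and the `run_attention` stand-in below are illustrative, not the vLLM kernel):

```python
import torch


def attend_with_padded_v(q: torch.Tensor, v: torch.Tensor, pad_v: bool,
                         run_attention):
    """Sketch: right-pad V up to the qk head dim, run the kernel, then slice
    the padded columns back off. `run_attention` is a stand-in callable."""
    maybe_padded_v = v
    if pad_v:
        maybe_padded_v = torch.nn.functional.pad(
            v, [0, q.shape[-1] - v.shape[-1]], value=0)

    out = run_attention(q, maybe_padded_v)

    if pad_v:
        # Zero-padded V columns only ever produce zeros; drop them again.
        out = out[..., :v.shape[-1]]
    return out


# Illustrative MLA-style shapes: qk head dim 192, v head dim 128.
q = torch.randn(4, 16, 192)
v = torch.randn(4, 16, 128)
out = attend_with_padded_v(q, v, pad_v=True,
                           run_attention=lambda q_, v_: v_.clone())
assert out.shape[-1] == 128
```
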
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
 
             if output is None:
                 output = attn_output
@@ -1252,58 +1295,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            ## triton flash attention always return 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_softmax_lse=has_context,
-            )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
+        output = self._flash_attn_varlen_diff_headdims(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=prefill_metadata.query_start_loc,
+            cu_seqlens_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax_lse=has_context,
+        )
 
         if has_context:
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(
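
Note: because the helper above already removes any V padding, the old view/slice/reshape epilogue before `o_proj` collapses to a plain `flatten(start_dim=-2)`. A small check of that equivalence on an already-unpadded output (tensor sizes are illustrative only):

```python
import torch

num_tokens, num_heads, v_head_dim = 8, 16, 128
output = torch.randn(num_tokens, num_heads, v_head_dim)  # already unpadded

flat_new = output.flatten(start_dim=-2)                 # new epilogue
flat_old = output.reshape(-1, num_heads * v_head_dim)   # old epilogue, post-slice
assert torch.equal(flat_new, flat_old)
```
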
@@ -2,8 +2,10 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -11,6 +13,7 @@ import torch
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
 
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+@dataclass
+class MLADims:
+    q_lora_rank: Optional[int]
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
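
Note: `get_mla_dims` gathers the MLA head geometry from the HF text config in one place instead of each backend reading the attributes separately. A hedged usage sketch with a stand-in config object (the numbers are illustrative, not tied to any particular model):

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass
class MLADims:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


# Stand-in for model_config.hf_text_config.
hf_text_config = SimpleNamespace(kv_lora_rank=512, qk_nope_head_dim=128,
                                 qk_rope_head_dim=64, v_head_dim=128)

dims = MLADims(
    q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
    kv_lora_rank=hf_text_config.kv_lora_rank,
    qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
    qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
    v_head_dim=hf_text_config.v_head_dim,
)
print(dims.qk_nope_head_dim + dims.qk_rope_head_dim)  # qk head dim: 192 here
```
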
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                      get_scheduler_metadata)
 
 logger = init_logger(__name__)
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]
 
+    # Optional aot scheduling
+    scheduler_metadata: Optional[torch.Tensor] = None
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None
+
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
+        model_config = runner.model_config
+
         self.runner = runner
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.headdim = model_config.get_head_size()
+        self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
         )
 
         use_cascade = common_prefix_len > 0
+
+        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
+                     causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=num_reqs,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads,
+                    num_heads_kv=self.num_heads,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
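
Note: with FA3 the builder can now run the scheduler ahead of time. The nested `schedule()` closure forwards the per-call sequence lengths together with the cached head/page geometry to `get_scheduler_metadata`, and returns `None` on FA2 so the kernel falls back to its own scheduling. A sketch of that closure pattern with a stub in place of the real `get_scheduler_metadata` (the `make_schedule` helper and the stub are hypothetical):

```python
from typing import Callable


def make_schedule(aot_schedule: bool, get_scheduler_metadata: Callable,
                  num_reqs: int, num_heads: int, headdim: int,
                  page_size: int) -> Callable:
    """Build a schedule() closure over the cached geometry, as the
    FlashAttentionMetadataBuilder does."""

    def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len, causal):
        if aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                cache_seqlens=seqlens,
                num_heads_q=num_heads,
                num_heads_kv=num_heads,
                headdim=headdim,
                page_size=page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
            )
        return None  # FA2: no ahead-of-time scheduling

    return schedule


# Stubbed usage: echo the kwargs instead of calling the FA3 scheduler.
schedule = make_schedule(True, lambda **kw: kw, num_reqs=2, num_heads=32,
                         headdim=128, page_size=16)
meta = schedule(cu_query_lens=[0, 3, 7], max_query_len=4,
                seqlens=[10, 12], max_seq_len=12, causal=True)
assert meta["causal"] is True
```
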
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
                                                    common_prefix_len)
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
+            prefix_scheduler_metadata = schedule(
+                cu_query_lens=cu_prefix_query_lens,
+                max_query_len=num_actual_tokens,
+                seqlens=prefix_kv_lens,
+                max_seq_len=common_prefix_len,
+                causal=False)
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=suffix_kv_lens,
+                                          max_seq_len=max_seq_len -
+                                          common_prefix_len,
+                                          causal=True)
         else:
             cu_prefix_query_lens = None
             prefix_kv_lens = None
             suffix_kv_lens = None
+            prefix_scheduler_metadata = None
+            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+                                          max_query_len=max_query_len,
+                                          seqlens=seq_lens,
+                                          max_seq_len=max_seq_len,
+                                          causal=True)
 
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
+            scheduler_metadata=scheduler_metadata,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
+            prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
         return attn_metadata
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
+                scheduler_metadata=attn_metadata.scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
             fa_version=self.vllm_flash_attn_version,
+            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
+            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
             q_descale=layer._q_scale,
             k_descale=layer._k_scale,
             v_descale=layer._v_scale,
@@ -636,6 +689,8 @@ def cascade_attention(
     block_table: torch.Tensor,
     common_prefix_len: int,
     fa_version: int,
+    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
+    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
     q_descale: Optional[torch.Tensor] = None,
     k_descale: Optional[torch.Tensor] = None,
     v_descale: Optional[torch.Tensor] = None,
@@ -667,6 +722,7 @@ def cascade_attention(
         block_table=block_table[:1],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=prefix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
@@ -693,6 +749,7 @@ def cascade_attention(
         block_table=block_table[:, num_common_kv_blocks:],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
+        scheduler_metadata=suffix_scheduler_metadata,
         fa_version=fa_version,
         q_descale=q_descale.expand(descale_shape)
         if q_descale is not None else None,
@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
+from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
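
Note: `is_vllm_fa` records which flash-attention package was importable so later code can pick the matching keyword (`return_softmax_lse` in vllm_flash_attn vs `return_attn_probs` in upstream flash-attn). A minimal sketch of that dispatch with a stub kernel (the stub and helper names are hypothetical):

```python
def fake_varlen_attn(q, k, v, return_softmax_lse=False,
                     return_attn_probs=False):
    """Stub standing in for flash_attn_varlen_func."""
    out, lse = "out", "lse"
    if return_softmax_lse or return_attn_probs:
        return out, lse
    return out


def call_attn(q, k, v, is_vllm_fa: bool, want_lse: bool):
    if is_vllm_fa:
        return fake_varlen_attn(q, k, v, return_softmax_lse=want_lse)
    # Upstream flash-attn spells the same switch differently.
    return fake_varlen_attn(q, k, v, return_attn_probs=want_lse)


assert call_attn(None, None, None, is_vllm_fa=True, want_lse=True) == ("out", "lse")
assert call_attn(None, None, None, is_vllm_fa=False, want_lse=False) == "out"
```
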
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
         model_config = runner.model_config
         cache_config = runner.cache_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.num_heads = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.mla_dims = get_mla_dims(model_config)
+        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+
+        # Dont try to access the runner on AMD
+        if self.aot_schedule:
+            self.page_size = self.runner.block_size
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
                 dtype=model_config.dtype,
                 device=runner.device,
             )
-        self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
 
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         seq_lens = seq_lens_cpu.to(device, non_blocking=True)
-        max_query_len = seq_lens_cpu.max().item()
 
         prefill_metadata = None
         if self._num_prefills > 0:
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
             if self.chunked_prefill_enabled and self._num_prefills > 0 \
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
             prefill_metadata = MLACommonPrefillMetadata(
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
-                query_start_loc=query_start_loc[reqs_start:] -
-                query_start_loc[reqs_start],
+                query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
             )
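
Note: `prefill_query_start_loc` rebases the batch-wide cumulative query offsets so the prefill slice starts at zero (decode requests are ordered before prefills by `reorder_batch`). A worked example with illustrative values:

```python
import torch

# Cumulative query starts for 4 requests: 2 decodes followed by 2 prefills.
query_start_loc = torch.tensor([0, 1, 2, 10, 30])
reqs_start = 2  # index of the first prefill request

prefill_query_start_loc = (query_start_loc[reqs_start:] -
                           query_start_loc[reqs_start])
print(prefill_query_start_loc)  # tensor([ 0,  8, 28])
```
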
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
+        self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        attn_out = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=maybe_padded_v,
+            return_softmax_lse=return_softmax_lse,
+            softmax_scale=softmax_scale,
+            **kwargs,
+        )
+
+        # Unpack the output if there is multiple results
+        lse = None
+        if isinstance(attn_out, tuple):
+            attn_out, lse = attn_out[0], attn_out[1]
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            return attn_out, lse
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
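
Note: in the v1 backend the padding decision is made once in `__init__`: V is left unpadded only when FA3 runs on a Hopper-class GPU (compute capability 9.x), which supports differing qk/v head dims natively. A sketch of that predicate as a standalone function (the function name is hypothetical):

```python
def needs_v_padding(fa_version, device_capability_major) -> bool:
    """Mirror of the _pad_v condition: pad V unless FA3 is on Hopper."""
    return fa_version is None or not (fa_version == 3
                                      and device_capability_major == 9)


assert needs_v_padding(2, 9) is True      # FA2: always pad
assert needs_v_padding(3, 8) is True      # FA3 on pre-Hopper: pad
assert needs_v_padding(3, 9) is False     # FA3 on Hopper: no padding needed
assert needs_v_padding(None, 9) is True   # no vllm FA version available
```
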
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
                 q=q,
                 k=k,
-                v=v_padded,
+                v=v,
                 cu_seqlens_q=prefill_metadata.query_start_loc,
                 cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
                 max_seqlen_q=prefill_metadata.max_query_len,
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        output = self.flash_attn_varlen_func(
+        output = self._flash_attn_varlen_diff_headdims(
             q=q,
             k=k,
-            v=v_padded,
+            v=v,
             cu_seqlens_q=attn_metadata.prefill.query_start_loc,
             cu_seqlens_k=attn_metadata.prefill.query_start_loc,
             max_seqlen_q=attn_metadata.prefill.max_query_len,
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(