diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index afd7c47e..110ef266 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 + GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 54278f5f..2ec771a6 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj - self.triton_fa_func = triton_attention + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 @@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + **kwargs, + ) + if is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) if output is None: output = attn_output @@ -1252,58 +1295,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: - output = self.triton_fa_func( - q, - k, - v_padded, - None, - prefill_metadata.query_start_loc, - prefill_metadata.query_start_loc, - prefill_metadata.max_prefill_seq_len, - prefill_metadata.max_prefill_seq_len, - True, # causal - self.scale, - None, # attn_mask is None unless applying ALiBi mask - ) - ## triton flash attention always return 2 objects - if not has_context: - output = output[0] - elif is_vllm_fa: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - else: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_attn_probs=has_context, - ) + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) if has_context: # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output + suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) @@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index b4413c36..89f1ea9b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -2,8 +2,10 @@ """Attention backend utils""" from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -11,6 +13,7 @@ import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b4c7708d..c039cd80 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -23,7 +23,8 @@ if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_cuda(): - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) logger = init_logger(__name__) @@ -93,6 +94,10 @@ class FlashAttentionMetadata: prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -277,7 +282,14 @@ def make_local_attention_virtual_batches( class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): + model_config = runner.model_config + self.runner = runner + self.aot_schedule = (get_flash_attn_version() == 3) + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder: ) use_cascade = common_prefix_len > 0 + + def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len, + causal): + if self.aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads, + num_heads_kv=self.num_heads, + headdim=self.headdim, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None + if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, @@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder: common_prefix_len) suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( self.runner.device) + prefix_scheduler_metadata = schedule( + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False) + scheduler_metadata = schedule(cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - + common_prefix_len, + causal=True) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule(cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True) attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder: slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, ) return attn_metadata @@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl): window_size=self.sliding_window, block_table=block_table, softcap=self.logits_soft_cap, + scheduler_metadata=attn_metadata.scheduler_metadata, fa_version=self.vllm_flash_attn_version, q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), @@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl): block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, @@ -636,6 +689,8 @@ def cascade_attention( block_table: torch.Tensor, common_prefix_len: int, fa_version: int, + prefix_scheduler_metadata: Optional[torch.Tensor] = None, + suffix_scheduler_metadata: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -667,6 +722,7 @@ def cascade_attention( block_table=block_table[:1], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, @@ -693,6 +749,7 @@ def cascade_attention( block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b77e9525..c0a6bd29 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,6 +195,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) +from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func + is_vllm_fa = False if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]): model_config = runner.model_config cache_config = runner.cache_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.mla_dims = get_mla_dims(model_config) + self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + + # Dont try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]): dtype=model_config.dtype, device=runner.device, ) - self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]): seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(device, non_blocking=True) - max_query_len = seq_lens_cpu.max().item() prefill_metadata = None if self._num_prefills > 0: @@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]): num_computed_tokens_cpu_tensor[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ @@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]): prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], - query_start_loc=query_start_loc[reqs_start:] - - query_start_loc[reqs_start], + query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) @@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + return attn_out, lse + return attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, @@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self.flash_attn_varlen_func( + output = self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=attn_metadata.prefill.query_start_loc, cu_seqlens_k=attn_metadata.prefill.query_start_loc, max_seqlen_q=attn_metadata.prefill.max_query_len, @@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode(