[Attention] Use FA3 for MLA on Hopper (#12807)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Authored by Lucas Wilkinson on 2025-02-06 06:43:12 -05:00; committed by GitHub
parent cefd56ee35
commit c786e757fa
4 changed files with 51 additions and 59 deletions
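
In short: instead of each FlashAttention backend resolving an FA version in its constructor, the version is now computed once, at import time, as a shared VLLM_FLASH_ATTN_VERSION constant in the attention backend utils, and the MLA prefill path forwards it to flash_attn_varlen_func, so MLA also picks up FA3 on Hopper. A minimal paraphrase of the selection rule follows; the helper and argument names are illustrative, not vLLM's own:

    # Hedged sketch of the rule this commit centralizes, not the actual code:
    # prefer FA3 on Hopper (compute capability 9.x) when the installed
    # vllm_flash_attn wheel supports it, otherwise stay on FA2; an explicit
    # VLLM_FLASH_ATTN_VERSION override still wins.
    def pick_fa_version(device_capability_major, is_supported, override=None):
        fa_version = 3 if (device_capability_major >= 9
                           and is_supported(3)) else 2
        if override is not None:
            assert override in (2, 3)
            fa_version = override
        assert is_supported(fa_version)
        return fa_version

    # e.g. pick_fa_version(9, lambda v: True) -> 3; pick_fa_version(8, lambda v: True) -> 2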

Changed file 1 of 4

@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
-    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
-    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
-    is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
+    compute_slot_mapping, compute_slot_mapping_start_idx,
+    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
+    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
+    is_block_tables_empty)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
 
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl):
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 out=prefill_output,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )
         else:
             # prefix-enabled attention
@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
                 block_table=prefill_meta.block_tables,
                 softcap=logits_soft_cap,
                 out=prefill_output,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
         return output

Changed file 2 of 4

@@ -12,6 +12,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
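
The only functional change for MLA is this new fa_version kwarg on the prefill flash_attn_varlen_func call; with the shared constant it resolves to 3 on Hopper. To confirm what a given install will request, the constant can be inspected directly (a quick check, assuming a CUDA device is visible and the vllm_flash_attn wheel is importable; the constant is None otherwise):

    from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
    # Expected: 3 on Hopper (SM 9.x) when FA3 is supported, 2 on older GPUs,
    # and None if the vllm_flash_attn package is missing.
    print(VLLM_FLASH_ATTN_VERSION)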

Changed file 3 of 4

@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 import numpy as np
 import torch
 
+from vllm import envs
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.logger import logging
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import ModelRunnerBase
@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+try:
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        fa_version_unsupported_reason, is_fa_version_supported)
+
+    def flash_attn_version():
+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        #  use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            fa_version = 2
+
+        if envs.VLLM_FLASH_ATTN_VERSION is not None:
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+
+        if not is_fa_version_supported(fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         fa_version, fa_version_unsupported_reason(fa_version))
+
+        assert is_fa_version_supported(fa_version)
+        return fa_version
+
+    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
+except ImportError:
+    VLLM_FLASH_ATTN_VERSION = None
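
Because the default is combined with envs.VLLM_FLASH_ATTN_VERSION, FA2 can still be forced on Hopper, for example to compare the two kernels. A hedged sketch of doing that from Python, assuming the corresponding VLLM_FLASH_ATTN_VERSION environment variable is what vllm.envs reads; it must be set before this module is first imported, since the constant is evaluated at import time:

    import os
    # Force FA2 even on Hopper; remove this to get the new FA3 default.
    os.environ["VLLM_FLASH_ATTN_VERSION"] = "2"

    from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
    assert VLLM_FLASH_ATTN_VERSION == 2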

Changed file 4 of 4

@@ -10,13 +10,10 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 logger = init_logger(__name__)
@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
                 "are not implemented for "
                 "FlashAttentionImpl")
 
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: torch.nn.Module,
@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
             window_size=self.sliding_window,
             block_table=attn_metadata.block_table,
             softcap=self.logits_soft_cap,
-            fa_version=self.fa_version,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         return output
@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=self.fa_version,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         return output