[Attention] Use FA3 for MLA on Hopper (#12807)

commit c786e757fa (parent cefd56ee35)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
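
The diff below replaces the per-backend FlashAttention version selection with a single VLLM_FLASH_ATTN_VERSION constant computed in vllm.attention.backends.utils and passed as fa_version= by the flash-attention, V1, and MLA backends, so MLA prefill also runs FA3 on Hopper. As a minimal sketch of the selection policy (illustrative only; pick_fa_version, device_major, fa3_supported, and override are stand-in names, not vLLM APIs):

# Minimal sketch of the version-selection policy (not vLLM code).
from typing import Optional

def pick_fa_version(device_major: int,
                    fa3_supported: bool,
                    override: Optional[int] = None) -> int:
    """Prefer FA3 on Hopper (compute capability 9.x) when available, else FA2."""
    version = 3 if (device_major >= 9 and fa3_supported) else 2
    if override is not None:       # an explicit user override always wins
        assert override in (2, 3)
        version = override
    return version

print(pick_fa_version(9, True))              # 3 on Hopper with FA3 available
print(pick_fa_version(8, True))              # 2 on Ampere (for now)
print(pick_fa_version(9, True, override=2))  # 2 when pinned explicitly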
@@ -14,19 +14,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                                AttentionMetadataBuilder,
                                                AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
-    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
-    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
-    is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
+    compute_slot_mapping, compute_slot_mapping_start_idx,
+    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
+    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
+    is_block_tables_empty)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -644,25 +641,6 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type

-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -781,7 +759,7 @@ class FlashAttentionImpl(AttentionImpl):
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 out=prefill_output,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )
         else:
             # prefix-enabled attention
@@ -804,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
                 block_table=prefill_meta.block_tables,
                 softcap=logits_soft_cap,
                 out=prefill_output,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )

         if decode_meta := attn_metadata.decode_metadata:
@@ -833,7 +811,7 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -854,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=self.fa_version,
+                    fa_version=VLLM_FLASH_ATTN_VERSION,
                 )
         return output

@@ -12,6 +12,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -533,6 +534,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
+            fa_version=VLLM_FLASH_ATTN_VERSION,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
@@ -8,12 +8,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 import numpy as np
 import torch

+from vllm import envs
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.logger import logging
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import ModelRunnerBase

@@ -580,3 +585,32 @@ def get_num_prefill_decode_query_kv_tokens(

     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+try:
+    from vllm.vllm_flash_attn.flash_attn_interface import (
+        fa_version_unsupported_reason, is_fa_version_supported)
+
+    def flash_attn_version():
+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        #  use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            fa_version = 2
+
+        if envs.VLLM_FLASH_ATTN_VERSION is not None:
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+
+        if not is_fa_version_supported(fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         fa_version, fa_version_unsupported_reason(fa_version))
+
+        assert is_fa_version_supported(fa_version)
+        return fa_version
+
+    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
+except ImportError:
+    VLLM_FLASH_ATTN_VERSION = None
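
The try/except above leaves VLLM_FLASH_ATTN_VERSION as None when the vllm_flash_attn wheel is missing. A small probe using the same helpers the hunk imports, given as a sketch under that assumption rather than a supported entry point:

try:
    from vllm.vllm_flash_attn.flash_attn_interface import (
        fa_version_unsupported_reason, is_fa_version_supported)

    # Report which FlashAttention versions this build/GPU combination supports.
    for version in (2, 3):
        if is_fa_version_supported(version):
            print(f"FA{version}: supported")
        else:
            print(f"FA{version}: unsupported "
                  f"({fa_version_unsupported_reason(version)})")
except ImportError:
    print("vllm_flash_attn is not installed; no FA version available")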
@@ -10,13 +10,10 @@ import triton.language as tl

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
-                                  flash_attn_varlen_func,
-                                  is_fa_version_supported)
+from vllm.vllm_flash_attn import flash_attn_varlen_func

 logger = init_logger(__name__)

@@ -136,25 +133,6 @@ class FlashAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "FlashAttentionImpl")

-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        #  use FA3 as default for both
-        if current_platform.get_device_capability()[0] >= 9:
-            self.fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            self.fa_version = 2
-
-        if VLLM_FLASH_ATTN_VERSION is not None:
-            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
-            self.fa_version = VLLM_FLASH_ATTN_VERSION
-
-        if not is_fa_version_supported(self.fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         self.fa_version,
-                         fa_version_unsupported_reason(self.fa_version))
-
-        assert is_fa_version_supported(self.fa_version)
-
     def forward(
         self,
         layer: torch.nn.Module,
@@ -227,7 +205,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )
             return output

@@ -249,7 +227,7 @@ class FlashAttentionImpl(AttentionImpl):
                 logits_soft_cap=self.logits_soft_cap,
                 block_table=attn_metadata.block_table,
                 common_prefix_len=attn_metadata.common_prefix_len,
-                fa_version=self.fa_version,
+                fa_version=VLLM_FLASH_ATTN_VERSION,
             )
             return output

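
Because the shared constant honors envs.VLLM_FLASH_ATTN_VERSION, the FA version can still be pinned explicitly. A hedged usage sketch: it assumes the VLLM_FLASH_ATTN_VERSION environment variable backs that setting and that vllm.attention.backends.utils has not been imported yet, since the constant is computed once at import time.

import os

# Assumption: set the override before any vLLM attention module is imported.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "2"   # pin FA2 even on Hopper

from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION

# Prints 2 when the override applies; None if the vllm_flash_attn wheel is absent.
print(VLLM_FLASH_ATTN_VERSION)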