[Attention] Flash Attention 3 - fp8 (#14570)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Mickaël Seznec 2025-03-20 06:14:20 +01:00 committed by GitHub
parent ae65f3e237
commit a597a57595
15 changed files with 272 additions and 76 deletions

View File

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
+        GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
+QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
@@ -85,6 +86,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     use_out: bool,
@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
     num_blocks: int,
     sliding_window: Optional[int],
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")
+
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
     q = query.unsqueeze(1)
     out = torch.empty_like(q) if use_out else None

+    maybe_quantized_query = q
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_with_kvcache(
-        q=q,
-        k_cache=key_cache,
-        v_cache=value_cache,
+        q=maybe_quantized_query,
+        k_cache=maybe_quantized_key_cache,
+        v_cache=maybe_quantized_value_cache,
         out=out,
         softmax_scale=scale,
         causal=True,
@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out
     output = output.squeeze(1)

+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
                                 value_cache=value_cache,
@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
                                 scale=scale,
                                 soft_cap=soft_cap,
                                 sliding_window=sliding_window)
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     use_out: bool,
@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
     soft_cap: Optional[float],
     num_blocks: int,
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")
+
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
                                  dtype=torch.int32)
     out = torch.empty_like(query) if use_out else None

+    maybe_quantized_query = query
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_varlen_func(
-        q=query,
-        k=key_cache,
-        v=value_cache,
+        q=maybe_quantized_query,
+        k=maybe_quantized_key_cache,
+        v=maybe_quantized_value_cache,
         out=out,
         cu_seqlens_q=cu_query_lens,
         seqused_k=kv_lens,
@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out

@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
     )
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - ref_output))}"

View File

@@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend

 __all__ = [
-    "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
-    "AttentionMetadataBuilder", "Attention", "AttentionState",
-    "get_attn_backend", "get_flash_attn_version"
+    "Attention",
+    "AttentionBackend",
+    "AttentionMetadata",
+    "AttentionType",
+    "AttentionMetadataBuilder",
+    "Attention",
+    "AttentionState",
+    "get_attn_backend",
 ]

View File

@@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):

 class AttentionLayer(Protocol):

+    _q_scale: torch.Tensor
     _k_scale: torch.Tensor
     _v_scale: torch.Tensor
     _k_scale_float: float

View File

@@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 # yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
-    compute_slot_mapping_start_idx, get_flash_attn_version,
-    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
-    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
-    is_block_tables_empty)
+    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
+    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
+    is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -630,9 +630,11 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
+        self.vllm_flash_attn_version = get_flash_attn_version()
+        if (is_quantized_kv_cache(self.kv_cache_dtype)
+                and self.vllm_flash_attn_version != 3):
             raise NotImplementedError(
-                "FlashAttention with FP8 KV cache not yet supported")
+                "Only FlashAttention3 supports FP8 KV cache")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -647,7 +649,6 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
-        self.vllm_flash_attn_version = get_flash_attn_version()

     def forward(
         self,
@@ -671,13 +672,19 @@ class FlashAttentionImpl(AttentionImpl):
                 for profiling run.
             attn_metadata: Metadata for attention.
         NOTE: It in-place updates the output tensor.
+        NOTE: For FP8 quantization, flash-attn expects the size of
+              {q,k,v}_descale to be (num_sequences, num_kv_heads).
+              We use torch's .expand() to avoid duplicating values.
         """
-        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
-            "key/v_scale is not supported in FlashAttention.")
-
         assert output is not None, "Output tensor must be provided."

+        # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
+        if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
+            assert (
+                layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
+                    "key/v_scale is only supported in FlashAttention 3 with "
+                    "base dtype bfloat16")
+
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -694,6 +701,7 @@ class FlashAttentionImpl(AttentionImpl):
         window_size = self.sliding_window
         alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
         logits_soft_cap: Optional[float] = self.logits_soft_cap
+        fp8_attention = kv_cache_dtype.startswith("fp8")

         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
@@ -729,6 +737,19 @@ class FlashAttentionImpl(AttentionImpl):
                 layer._v_scale,
             )

+            if fp8_attention:
+                kv_cache = kv_cache.view(torch.float8_e4m3fn)
+                key_cache = key_cache.view(torch.float8_e4m3fn)
+                value_cache = value_cache.view(torch.float8_e4m3fn)
+
+        if fp8_attention:
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         (num_prefill_query_tokens, num_prefill_kv_tokens,
          num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -753,6 +774,23 @@ class FlashAttentionImpl(AttentionImpl):
                 key = key[:num_prefill_kv_tokens]
                 value = value[:num_prefill_kv_tokens]

+                if fp8_attention:
+                    num_kv_tokens, num_kv_heads, head_size = key.shape
+
+                    key, _ = ops.scaled_fp8_quant(
+                        key.reshape((num_kv_tokens,
+                                     num_kv_heads * head_size)).contiguous(),
+                        layer._k_scale)
+                    key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+                    value, _ = ops.scaled_fp8_quant(
+                        value.reshape((num_kv_tokens,
+                                       num_kv_heads * head_size)).contiguous(),
+                        layer._v_scale)
+                    value = value.reshape(
+                        (num_kv_tokens, num_kv_heads, head_size))
+
+                descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
+
                 flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -768,13 +806,19 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=prefill_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
             else:
                 # prefix-enabled attention
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support prefix caching")
                 assert prefill_meta.seq_lens is not None
+                assert prefill_meta.query_start_loc is not None
                 max_seq_len = max(prefill_meta.seq_lens)
+                descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
+                                 key.shape[1])
                 flash_attn_varlen_func(  # noqa
                     q=query,
                     k=key_cache,
@@ -791,6 +835,9 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=prefill_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -804,6 +851,9 @@ class FlashAttentionImpl(AttentionImpl):
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support max_decode_query_len > 1"
                 )
+                assert decode_meta.query_start_loc is not None
+                descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
+                                 key.shape[1])
                 flash_attn_varlen_func(
                     q=decode_query,
                     k=key_cache,
@@ -820,6 +870,9 @@ class FlashAttentionImpl(AttentionImpl):
                     block_table=decode_meta.block_tables,
                     out=decode_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -828,6 +881,7 @@ class FlashAttentionImpl(AttentionImpl):
                     _,
                     block_tables_arg,
                 ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
+                descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
                 flash_attn_with_kvcache(
                     q=decode_query.unsqueeze(1),
                     k_cache=key_cache,
@@ -841,6 +895,9 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
         return output
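Throughout this backend, layer._q_scale, layer._k_scale and layer._v_scale are per-tensor scalars, while the FA3 kernel takes {q,k,v}_descale of shape (num_sequences, num_kv_heads); descale_shape is computed accordingly (for the varlen calls, query_start_loc.shape[0] - 1 is the number of sequences and key.shape[1] the number of KV heads). The .expand() calls broadcast the scalar to that shape as a view, so no extra memory is allocated. A small standalone illustration of that behaviour:

```python
import torch

# Per-tensor scale as stored on the attention layer (e.g. layer._q_scale).
q_scale = torch.tensor(0.02, dtype=torch.float32)

num_seqs, num_kv_heads = 8, 4  # e.g. query_start_loc.shape[0] - 1, key.shape[1]
q_descale = q_scale.expand(num_seqs, num_kv_heads)

assert q_descale.shape == (num_seqs, num_kv_heads)
# expand() returns a broadcast view: storage is shared, nothing is copied.
assert q_descale.data_ptr() == q_scale.data_ptr()
assert q_descale.stride() == (0, 0)
```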

View File

@@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionState, MLAAttentionImpl)
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
-                                           get_flash_attn_version,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)

View File

@@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 import numpy as np
 import torch

-from vllm import envs
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 logger = init_logger(__name__)
@@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens(
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
-
-
-def get_flash_attn_version():
-    try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason, is_fa_version_supported)
-
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        # use FA3 as default for both
-        if current_platform.get_device_capability()[0] == 9:
-            fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            fa_version = 2
-
-        if envs.VLLM_FLASH_ATTN_VERSION is not None:
-            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
-            fa_version = envs.VLLM_FLASH_ATTN_VERSION
-
-        if (current_platform.get_device_capability()[0] == 10
-                and envs.VLLM_FLASH_ATTN_VERSION == 3):
-            logger.warning("Cannot use FA version 3 on Blackwell platform",
-                           "defaulting to FA version 2.")
-            fa_version = 2
-
-        if not is_fa_version_supported(fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         fa_version, fa_version_unsupported_reason(fa_version))
-
-        assert is_fa_version_supported(fa_version)
-        return fa_version
-
-    except (ImportError, AssertionError):
-        return None

View File

@@ -84,6 +84,9 @@ class Attention(nn.Module):
         self.calculate_kv_scales = calculate_kv_scales
         self._k_scale = torch.tensor(1.0, dtype=torch.float32)
         self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        # FlashAttn doesn't support quantizing the kv-cache only
+        # but requires q to be quantized as well.
+        self._q_scale = torch.tensor(1.0, dtype=torch.float32)

         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)
@@ -153,6 +156,7 @@ class Attention(nn.Module):
                 ).parallel_config.pipeline_parallel_size)
         ]

+        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
@@ -178,7 +182,7 @@ class Attention(nn.Module):
         if self.calculate_kv_scales:
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(key, value)
+                self.calc_kv_scales(query, key, value)
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
@@ -225,7 +229,8 @@ class Attention(nn.Module):
             return torch.ops.vllm.unified_attention(
                 query, key, value, self.layer_name)

-    def calc_kv_scales(self, key, value):
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
         self._k_scale_float = self._k_scale.item()
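calc_kv_scales now derives a per-tensor scale for Q the same way it already did for K and V: the maximum magnitude divided by a configurable constant. A rough standalone sketch of the arithmetic (using the default divisor of 200 that this commit adds to envs):

```python
import torch

Q_SCALE_CONSTANT = 200  # default divisor, see envs.Q_SCALE_CONSTANT below

def dynamic_q_scale(query: torch.Tensor) -> torch.Tensor:
    # scale = max(|q|) / constant, so quantized values q / scale peak at
    # roughly `constant`, well inside float8_e4m3fn's ~448 range while
    # leaving headroom for later batches with larger activations.
    return torch.abs(query).max().to(torch.float32) / Q_SCALE_CONSTANT

q = torch.randn(16, 32, 128, dtype=torch.bfloat16)
print(dynamic_q_scale(q))
```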

View File

@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
+    Q_SCALE_CONSTANT: int = 200
     K_SCALE_CONSTANT: int = 200
     V_SCALE_CONSTANT: int = 100
     VLLM_SERVER_DEV_MODE: bool = False
@@ -524,13 +525,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Pad the fp8 weights to 256 bytes for ROCm
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),

+    # Divisor for dynamic query scale factor calculation for FP8 KV Cache
+    "Q_SCALE_CONSTANT":
+    lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),

     # Divisor for dynamic key scale factor calculation for FP8 KV Cache
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),

     # Divisor for dynamic value scale factor calculation for FP8 KV Cache
     "V_SCALE_CONSTANT":
     lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),

     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
     lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),

vllm/fa_utils.py (new file, 42 lines)
View File

@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+from vllm import envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def get_flash_attn_version() -> Optional[int]:
+    # import here to avoid circular dependencies
+    from vllm.platforms import current_platform
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)
+        device_capability = current_platform.get_device_capability()
+
+        assert device_capability is not None
+
+        # 1. default version depending on platform
+        fa_version = 3 if (device_capability.major == 9
+                           and is_fa_version_supported(3)) else 2
+
+        # 2. override if passed by environment
+        if envs.VLLM_FLASH_ATTN_VERSION is not None:
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+
+        # 3. fallback for unsupported combinations
+        if device_capability.major == 10 and fa_version == 3:
+            logger.warning("Cannot use FA version 3 on Blackwell platform, "
+                           "defaulting to FA version 2.")
+            fa_version = 2
+
+        if not is_fa_version_supported(fa_version):
+            logger.error("Cannot use FA version %d: not supported due to %s",
+                         fa_version, fa_version_unsupported_reason(fa_version))
+
+        assert is_fa_version_supported(fa_version)
+        return fa_version
+    except (ImportError, AssertionError):
+        return None
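get_flash_attn_version() returns None rather than raising when vllm_flash_attn is missing or no supported version can be selected, so callers can branch on the result. A minimal usage sketch (assuming a vLLM build containing this change):

```python
from vllm.fa_utils import get_flash_attn_version

fa_version = get_flash_attn_version()
if fa_version == 3:
    # Hopper + FA3: FP8 Q/K/V with per-tensor descale factors is available.
    print("Using FlashAttention 3")
elif fa_version == 2:
    print("Using FlashAttention 2 (no FP8 attention path)")
else:
    # None: vllm_flash_attn is not importable or the version is unsupported.
    print("FlashAttention backend unavailable")
```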

View File

@@ -26,11 +26,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):

     def create_weights(self, layer: torch.nn.Module):
         """
-        Create "weight" (aka k_scale and v_scale) for an attention layer.
+        Create "weight" (aka q_scale, k_scale and v_scale)
+        for an attention layer.
         """
-        # Initialize the KV cache scales to -1.0, which is an invalid value.
-        # If the k/v_scale appears in the checkpoint, it will be
+        # Initialize the Q and KV cache scales to -1.0, an invalid value.
+        # If the q and k/v_scales appear in the checkpoint, they will be
         # overwritten when loading weights.
+        layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                           requires_grad=False)
         layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
@@ -75,6 +78,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 raise ValueError("Only support per-tensor scaling factor "
                                  "for fp8 KV cache")

+            if layer.q_scale < 0.0:
+                logger.warning_once(
+                    "Checkpoint does not provide a q scaling factor. "
+                    "Setting it to k_scale. This only matters for "
+                    "the flash-attn backend.")
+                layer._q_scale.copy_(k_scale)
+
             # These are used in the final Attention.forward()
             layer._k_scale.copy_(k_scale)
             layer._v_scale.copy_(v_scale)

View File

@@ -14,6 +14,7 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
@@ -240,15 +241,6 @@ class CudaPlatformBase(Platform):
                         "Cannot use FlashAttention-2 backend for dtype other than "
                         "torch.float16 or torch.bfloat16.")
                     target_backend = _Backend.XFORMERS
-                elif kv_cache_dtype is not None and \
-                        kv_cache_dtype.startswith("fp8"):
-                    logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-                    logger.warning(
-                        "Please use FlashInfer backend with FP8 KV Cache for "
-                        "better performance by setting environment variable "
-                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
-                    target_backend = _Backend.XFORMERS
                 elif block_size % 16 != 0:
                     logger.info(
                         "Cannot use FlashAttention-2 backend for block size not "
@@ -270,6 +262,17 @@ class CudaPlatformBase(Platform):
                         "Cannot use FlashAttention-2 backend for head size %d.",
                         head_size)
                     target_backend = _Backend.XFORMERS
+
+                fp8_kv_cache = (kv_cache_dtype is not None
+                                and kv_cache_dtype.startswith("fp8"))
+                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
+                    )
+                    logger.warning(
+                        "Please use FlashInfer backend with FP8 KV Cache for "
+                        "better performance by setting environment variable "
+                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                    target_backend = _Backend.XFORMERS
             except ImportError:
                 logger.info(
                     "Cannot use FlashAttention-2 backend because the "

View File

@@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional
 import numpy as np
 import torch

+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
-from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
+        NOTE: For FP8 quantization, flash-attn expects the size of
+              {q,k,v}_descale to be (num_sequences, num_kv_heads).
+              We use torch's .expand() to avoid duplicating values.
         """

         assert output is not None, "Output tensor must be provided."
@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
                 layer._k_scale,
                 layer._v_scale,
             )

+        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
+                         key.shape[1])
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(torch.float8_e4m3fn)
+            value_cache = value_cache.view(torch.float8_e4m3fn)
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         # Compute attention and update output up to `num_actual_tokens`.
         if not attn_metadata.use_cascade:
@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
                 fa_version=self.vllm_flash_attn_version,
+                q_descale=layer._q_scale.expand(descale_shape),
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
             )
             return output
@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
             fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale,
+            k_descale=layer._k_scale,
+            v_descale=layer._v_scale,
         )
         return output
@@ -391,6 +412,9 @@ def cascade_attention(
     block_table: torch.Tensor,
     common_prefix_len: int,
     fa_version: int,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
     # TODO: Support sliding window.
@@ -402,6 +426,7 @@ def cascade_attention(
     assert common_prefix_len % block_size == 0
     num_common_kv_blocks = common_prefix_len // block_size
     assert num_common_kv_blocks > 0
+    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

     # Process shared prefix.
     prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -419,8 +444,16 @@ def cascade_attention(
         softcap=logits_soft_cap,
         return_softmax_lse=True,
         fa_version=fa_version,
+        q_descale=q_descale.expand(descale_shape)
+        if q_descale is not None else None,
+        k_descale=k_descale.expand(descale_shape)
+        if k_descale is not None else None,
+        v_descale=v_descale.expand(descale_shape)
+        if v_descale is not None else None,
     )

+    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
+
     # Process suffix per query.
     suffix_output, suffix_lse = flash_attn_varlen_func(
         q=query,
@@ -437,6 +470,12 @@ def cascade_attention(
         softcap=logits_soft_cap,
         return_softmax_lse=True,
         fa_version=fa_version,
+        q_descale=q_descale.expand(descale_shape)
+        if q_descale is not None else None,
+        k_descale=k_descale.expand(descale_shape)
+        if k_descale is not None else None,
+        v_descale=v_descale.expand(descale_shape)
+        if v_descale is not None else None,
     )

     # Merge prefix and suffix outputs, and store the result in output.

View File

@@ -5,6 +5,7 @@ import pickle
 import signal
 import sys
 import time
+import traceback
 import weakref
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -370,6 +371,9 @@ class WorkerProc:
                 func = partial(cloudpickle.loads(method), self.worker)
                 output = func(*args, **kwargs)
             except Exception as e:
+                # Notes have been introduced in python 3.11
+                if hasattr(e, "add_note"):
+                    e.add_note(traceback.format_exc())
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.FAILURE, e))
                 logger.exception("WorkerProc hit an exception: %s", exc_info=e)
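Exception notes (PEP 678) exist only on Python 3.11+, hence the hasattr guard; attaching traceback.format_exc() preserves the worker-side stack on the exception object itself, so it remains visible after the error is pickled, sent over the response queue, and re-raised in the parent. A small standalone illustration:

```python
import traceback

def worker_task():
    raise ValueError("boom")

try:
    worker_task()
except Exception as e:
    # Guarded: add_note() only exists on Python 3.11+.
    if hasattr(e, "add_note"):
        e.add_note(traceback.format_exc())
    caught = e

# The note lives in the exception's __dict__, so it survives pickling and
# is printed alongside the traceback wherever the exception is re-raised.
print(getattr(caught, "__notes__", ["(no notes: Python < 3.11)"])[0])
```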

View File

@@ -1558,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     block_size=block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
-                    dtype=attn_module.dtype,
+                    dtype=self.kv_cache_dtype,
                     use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):