[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
parent
ae65f3e237
commit
a597a57595
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
QDTYPES = [None, torch.float8_e4m3fn]
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
@ -85,6 +86,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_flash_attn_with_paged_kv(
|
||||
use_out: bool,
|
||||
@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
|
||||
num_blocks: int,
|
||||
sliding_window: Optional[int],
|
||||
fa_version: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip("Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
|
||||
|
||||
q = query.unsqueeze(1)
|
||||
out = torch.empty_like(q) if use_out else None
|
||||
|
||||
maybe_quantized_query = q
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
q_descale = None
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
q_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
output = flash_attn_with_kvcache(
|
||||
q=q,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
q=maybe_quantized_query,
|
||||
k_cache=maybe_quantized_key_cache,
|
||||
v_cache=maybe_quantized_value_cache,
|
||||
out=out,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
window_size=window_size,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
output = output.squeeze(1)
|
||||
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window)
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@pytest.mark.parametrize("q_dtype", QDTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
use_out: bool,
|
||||
@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
fa_version: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip("Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
|
||||
dtype=torch.int32)
|
||||
|
||||
out = torch.empty_like(query) if use_out else None
|
||||
|
||||
maybe_quantized_query = query
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
q_descale = None
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
q_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
q=maybe_quantized_query,
|
||||
k=maybe_quantized_key_cache,
|
||||
v=maybe_quantized_value_cache,
|
||||
out=out,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens,
|
||||
@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
|
||||
@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
__all__ = [
|
||||
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
|
||||
"AttentionMetadataBuilder", "Attention", "AttentionState",
|
||||
"get_attn_backend", "get_flash_attn_version"
|
||||
"Attention",
|
||||
"AttentionBackend",
|
||||
"AttentionMetadata",
|
||||
"AttentionType",
|
||||
"AttentionMetadataBuilder",
|
||||
"Attention",
|
||||
"AttentionState",
|
||||
"get_attn_backend",
|
||||
]
|
||||
|
@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
|
||||
|
||||
class AttentionLayer(Protocol):
|
||||
|
||||
_q_scale: torch.Tensor
|
||||
_k_scale: torch.Tensor
|
||||
_v_scale: torch.Tensor
|
||||
_k_scale_float: float
|
||||
|
@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
# yapf: enable
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx, get_flash_attn_version,
|
||||
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
|
||||
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
|
||||
is_block_tables_empty)
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
@ -630,9 +630,11 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.sliding_window = ((sliding_window - 1,
|
||||
0) if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if (is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
and self.vllm_flash_attn_version != 3):
|
||||
raise NotImplementedError(
|
||||
"FlashAttention with FP8 KV cache not yet supported")
|
||||
"Only FlashAttention3 supports FP8 KV cache")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
@ -647,7 +649,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -671,13 +672,19 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
NOTE: It in-place updates the output tensor.
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
|
||||
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
|
||||
assert (
|
||||
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
|
||||
"key/v_scale is only supported in FlashAttention 3 with "
|
||||
"base dtype bfloat16")
|
||||
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
@ -694,6 +701,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
window_size = self.sliding_window
|
||||
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
|
||||
logits_soft_cap: Optional[float] = self.logits_soft_cap
|
||||
fp8_attention = kv_cache_dtype.startswith("fp8")
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
@ -729,6 +737,19 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if fp8_attention:
|
||||
kv_cache = kv_cache.view(torch.float8_e4m3fn)
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
|
||||
if fp8_attention:
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
@ -753,6 +774,23 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
|
||||
if fp8_attention:
|
||||
num_kv_tokens, num_kv_heads, head_size = key.shape
|
||||
|
||||
key, _ = ops.scaled_fp8_quant(
|
||||
key.reshape((num_kv_tokens,
|
||||
num_kv_heads * head_size)).contiguous(),
|
||||
layer._k_scale)
|
||||
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
|
||||
|
||||
value, _ = ops.scaled_fp8_quant(
|
||||
value.reshape((num_kv_tokens,
|
||||
num_kv_heads * head_size)).contiguous(),
|
||||
layer._v_scale)
|
||||
value = value.reshape(
|
||||
(num_kv_tokens, num_kv_heads, head_size))
|
||||
|
||||
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
|
||||
flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
@ -768,13 +806,19 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
assert prefill_meta.query_start_loc is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
@ -791,6 +835,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
@ -804,6 +851,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1"
|
||||
)
|
||||
assert decode_meta.query_start_loc is not None
|
||||
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
flash_attn_varlen_func(
|
||||
q=decode_query,
|
||||
k=key_cache,
|
||||
@ -820,6 +870,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
block_table=decode_meta.block_tables,
|
||||
out=decode_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
@ -828,6 +881,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
_,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
|
||||
flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
@ -841,6 +895,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=logits_soft_cap,
|
||||
out=decode_output.unsqueeze(1),
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
|
||||
|
@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionState, MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
get_flash_attn_version,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
|
@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
||||
AttentionState)
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens(
|
||||
|
||||
return (num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens)
|
||||
|
||||
|
||||
def get_flash_attn_version():
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
fa_version_unsupported_reason, is_fa_version_supported)
|
||||
|
||||
# if hopper default to FA3, otherwise stick to FA2 for now
|
||||
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
|
||||
# use FA3 as default for both
|
||||
if current_platform.get_device_capability()[0] == 9:
|
||||
fa_version = 3 if is_fa_version_supported(3) else 2
|
||||
else:
|
||||
fa_version = 2
|
||||
|
||||
if envs.VLLM_FLASH_ATTN_VERSION is not None:
|
||||
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
fa_version = envs.VLLM_FLASH_ATTN_VERSION
|
||||
if (current_platform.get_device_capability()[0] == 10
|
||||
and envs.VLLM_FLASH_ATTN_VERSION == 3):
|
||||
logger.warning("Cannot use FA version 3 on Blackwell platform",
|
||||
"defaulting to FA version 2.")
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||
fa_version, fa_version_unsupported_reason(fa_version))
|
||||
|
||||
assert is_fa_version_supported(fa_version)
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
||||
|
@ -84,6 +84,9 @@ class Attention(nn.Module):
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
# FlashAttn doesn't support quantizing the kv-cache only
|
||||
# but requires q to be quantized as well.
|
||||
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep the float32 versions of k/v_scale for attention
|
||||
# backends that don't support tensors (Flashinfer)
|
||||
@ -153,6 +156,7 @@ class Attention(nn.Module):
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
@ -178,7 +182,7 @@ class Attention(nn.Module):
|
||||
if self.calculate_kv_scales:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(key, value)
|
||||
self.calc_kv_scales(query, key, value)
|
||||
if self.use_output:
|
||||
output_shape = (output_shape
|
||||
if output_shape is not None else query.shape)
|
||||
@ -225,7 +229,8 @@ class Attention(nn.Module):
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
def calc_kv_scales(self, key, value):
|
||||
def calc_kv_scales(self, query, key, value):
|
||||
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
|
||||
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
|
||||
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
|
||||
self._k_scale_float = self._k_scale.item()
|
||||
|
@ -78,6 +78,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||
Q_SCALE_CONSTANT: int = 200
|
||||
K_SCALE_CONSTANT: int = 200
|
||||
V_SCALE_CONSTANT: int = 100
|
||||
VLLM_SERVER_DEV_MODE: bool = False
|
||||
@ -524,13 +525,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Pad the fp8 weights to 256 bytes for ROCm
|
||||
"VLLM_ROCM_FP8_PADDING":
|
||||
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
|
||||
|
||||
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
|
||||
"Q_SCALE_CONSTANT":
|
||||
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
|
||||
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
|
||||
"K_SCALE_CONSTANT":
|
||||
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
|
||||
|
||||
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
|
||||
"V_SCALE_CONSTANT":
|
||||
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
|
||||
|
||||
# If set, enable multiprocessing in LLM for the V1 code path.
|
||||
"VLLM_ENABLE_V1_MULTIPROCESSING":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
|
||||
|
42
vllm/fa_utils.py
Normal file
42
vllm/fa_utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_flash_attn_version() -> Optional[int]:
|
||||
# import here to avoid circular dependencies
|
||||
from vllm.platforms import current_platform
|
||||
try:
|
||||
from vllm.vllm_flash_attn.flash_attn_interface import (
|
||||
fa_version_unsupported_reason, is_fa_version_supported)
|
||||
device_capability = current_platform.get_device_capability()
|
||||
|
||||
assert device_capability is not None
|
||||
|
||||
# 1. default version depending on platform
|
||||
fa_version = 3 if (device_capability.major == 9
|
||||
and is_fa_version_supported(3)) else 2
|
||||
|
||||
# 2. override if passed by environment
|
||||
if envs.VLLM_FLASH_ATTN_VERSION is not None:
|
||||
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
fa_version = envs.VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
if device_capability.major == 10 and fa_version == 3:
|
||||
logger.warning("Cannot use FA version 3 on Blackwell platform",
|
||||
"defaulting to FA version 2.")
|
||||
fa_version = 2
|
||||
|
||||
if not is_fa_version_supported(fa_version):
|
||||
logger.error("Cannot use FA version %d is not supported due to %s",
|
||||
fa_version, fa_version_unsupported_reason(fa_version))
|
||||
|
||||
assert is_fa_version_supported(fa_version)
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
@ -26,11 +26,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Create "weight" (aka k_scale and v_scale) for an attention layer.
|
||||
Create "weight" (aka q_scale, k_scale and v_scale)
|
||||
for an attention layer.
|
||||
"""
|
||||
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||
# If the k/v_scale appears in the checkpoint, it will be
|
||||
# Initialize the Q and KV cache scales to -1.0, an invalid value.
|
||||
# If the q and k/v_scales appear in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
|
||||
@ -75,6 +78,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
raise ValueError("Only support per-tensor scaling factor "
|
||||
"for fp8 KV cache")
|
||||
|
||||
if layer.q_scale < 0.0:
|
||||
logger.warning_once(
|
||||
"Checkpoint does not provide a q scaling factor. "
|
||||
"Setting it to k_scale. This only matters for "
|
||||
"the flash-attn backend.")
|
||||
layer._q_scale.copy_(k_scale)
|
||||
|
||||
# These are used in the final Attention.forward()
|
||||
layer._k_scale.copy_(k_scale)
|
||||
layer._v_scale.copy_(v_scale)
|
||||
|
@ -14,6 +14,7 @@ from typing_extensions import ParamSpec
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import import_pynvml
|
||||
|
||||
@ -240,15 +241,6 @@ class CudaPlatformBase(Platform):
|
||||
"Cannot use FlashAttention-2 backend for dtype other than "
|
||||
"torch.float16 or torch.bfloat16.")
|
||||
target_backend = _Backend.XFORMERS
|
||||
elif kv_cache_dtype is not None and \
|
||||
kv_cache_dtype.startswith("fp8"):
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
||||
logger.warning(
|
||||
"Please use FlashInfer backend with FP8 KV Cache for "
|
||||
"better performance by setting environment variable "
|
||||
"VLLM_ATTENTION_BACKEND=FLASHINFER")
|
||||
target_backend = _Backend.XFORMERS
|
||||
elif block_size % 16 != 0:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for block size not "
|
||||
@ -270,6 +262,17 @@ class CudaPlatformBase(Platform):
|
||||
"Cannot use FlashAttention-2 backend for head size %d.",
|
||||
head_size)
|
||||
target_backend = _Backend.XFORMERS
|
||||
fp8_kv_cache = (kv_cache_dtype is not None
|
||||
and kv_cache_dtype.startswith("fp8"))
|
||||
if (fp8_kv_cache and get_flash_attn_version() != 3):
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for FP8 KV cache."
|
||||
)
|
||||
logger.warning(
|
||||
"Please use FlashInfer backend with FP8 KV Cache for "
|
||||
"better performance by setting environment variable "
|
||||
"VLLM_ATTENTION_BACKEND=FLASHINFER")
|
||||
target_backend = _Backend.XFORMERS
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend because the "
|
||||
|
@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.utils import get_flash_attn_version
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
if not attn_metadata.use_cascade:
|
||||
@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
|
||||
@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale,
|
||||
k_descale=layer._k_scale,
|
||||
v_descale=layer._v_scale,
|
||||
)
|
||||
return output
|
||||
|
||||
@ -391,6 +412,9 @@ def cascade_attention(
|
||||
block_table: torch.Tensor,
|
||||
common_prefix_len: int,
|
||||
fa_version: int,
|
||||
q_descale: Optional[torch.Tensor] = None,
|
||||
k_descale: Optional[torch.Tensor] = None,
|
||||
v_descale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
|
||||
# TODO: Support sliding window.
|
||||
@ -402,6 +426,7 @@ def cascade_attention(
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
assert num_common_kv_blocks > 0
|
||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
# Process shared prefix.
|
||||
prefix_output, prefix_lse = flash_attn_varlen_func(
|
||||
@ -419,8 +444,16 @@ def cascade_attention(
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale.expand(descale_shape)
|
||||
if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape)
|
||||
if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape)
|
||||
if v_descale is not None else None,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
# Process suffix per query.
|
||||
suffix_output, suffix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
@ -437,6 +470,12 @@ def cascade_attention(
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale.expand(descale_shape)
|
||||
if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape)
|
||||
if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape)
|
||||
if v_descale is not None else None,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
|
@ -5,6 +5,7 @@ import pickle
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
@ -370,6 +371,9 @@ class WorkerProc:
|
||||
func = partial(cloudpickle.loads(method), self.worker)
|
||||
output = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# Notes have been introduced in python 3.11
|
||||
if hasattr(e, "add_note"):
|
||||
e.add_note(traceback.format_exc())
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, e))
|
||||
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
|
||||
|
@ -1558,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
|
Loading…
x
Reference in New Issue
Block a user