[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Parent: ae65f3e237
Commit: a597a57595
@@ -38,7 +38,7 @@ else()
         FetchContent_Declare(
                 vllm-flash-attn
                 GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-                GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
+                GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
                 GIT_PROGRESS TRUE
                 # Don't share the vllm-flash-attn build between build types
                 BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -15,6 +15,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
 BLOCK_SIZES = [16, 32]
 DTYPES = [torch.float16, torch.bfloat16]
+QDTYPES = [None, torch.float8_e4m3fn]
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
@@ -85,6 +86,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     use_out: bool,
@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
     num_blocks: int,
     sliding_window: Optional[int],
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")

     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(

     q = query.unsqueeze(1)
     out = torch.empty_like(q) if use_out else None
+
+    maybe_quantized_query = q
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_with_kvcache(
-        q=q,
-        k_cache=key_cache,
-        v_cache=value_cache,
+        q=maybe_quantized_query,
+        k_cache=maybe_quantized_key_cache,
+        v_cache=maybe_quantized_value_cache,
         out=out,
         softmax_scale=scale,
         causal=True,
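The unit descale factors above are sufficient because the test inputs are drawn from N(0, 1), which fits well inside the float8_e4m3fn range. A minimal torch-only sketch of the quantize/dequantize round trip the fp8 path implies (illustrative names, not the vLLM API):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3fn

def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the scale, clamp into the representable range, cast to fp8.
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequantize_fp8(x_fp8: torch.Tensor, descale: torch.Tensor) -> torch.Tensor:
    # The descale factor undoes the scaling on the way back to high precision.
    return x_fp8.to(torch.float32) * descale

x = torch.randn(16, 8)             # N(0, 1) values, like the test tensors
scale = torch.tensor(1.0)          # unit scale, matching torch.ones(...) above
roundtrip = dequantize_fp8(quantize_fp8(x, scale), scale)
print((roundtrip - x).abs().max())  # small quantization error, no overflow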
@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out
     output = output.squeeze(1)

+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
                                 value_cache=value_cache,
@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
                                 scale=scale,
                                 soft_cap=soft_cap,
                                 sliding_window=sliding_window)
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
        f"{torch.max(torch.abs(output - ref_output))}"


@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("fa_version", [2, 3])
+@pytest.mark.parametrize("q_dtype", QDTYPES)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
     use_out: bool,
@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
     soft_cap: Optional[float],
     num_blocks: int,
     fa_version: int,
+    q_dtype: Optional[torch.dtype],
 ) -> None:
     torch.set_default_device("cuda")
     if not is_fa_version_supported(fa_version):
         pytest.skip(f"Flash attention version {fa_version} not supported due "
                     f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
+    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
+        pytest.skip("Flash attention with quantized inputs is only "
+                    "supported on version 3 with bfloat16 base type")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
                                 dtype=torch.int32)

     out = torch.empty_like(query) if use_out else None
+
+    maybe_quantized_query = query
+    maybe_quantized_key_cache = key_cache
+    maybe_quantized_value_cache = value_cache
+    q_descale = None
+    k_descale = None
+    v_descale = None
+    if q_dtype is not None:
+        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
+        maybe_quantized_query = query.to(q_dtype)
+        maybe_quantized_key_cache = key_cache.to(q_dtype)
+        maybe_quantized_value_cache = value_cache.to(q_dtype)
+
+        scale_shape = (num_seqs, num_kv_heads)
+        q_descale = torch.ones(scale_shape, dtype=torch.float32)
+        k_descale = torch.ones(scale_shape, dtype=torch.float32)
+        v_descale = torch.ones(scale_shape, dtype=torch.float32)
+
     output = flash_attn_varlen_func(
-        q=query,
-        k=key_cache,
-        v=value_cache,
+        q=maybe_quantized_query,
+        k=maybe_quantized_key_cache,
+        v=maybe_quantized_value_cache,
         out=out,
         cu_seqlens_q=cu_query_lens,
         seqused_k=kv_lens,
@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
         fa_version=fa_version,
+        q_descale=q_descale,
+        k_descale=k_descale,
+        v_descale=v_descale,
     )
     output = output if not use_out else out

@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
                                 sliding_window=sliding_window,
                                 soft_cap=soft_cap,
                                 )
-    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+    atol, rtol = 1.5e-2, 1e-2
+    if q_dtype is not None:
+        atol, rtol = 1.5e-1, 1.5e-1
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
        f"{torch.max(torch.abs(output - ref_output))}"
@@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
-from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend

 __all__ = [
-    "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
-    "AttentionMetadataBuilder", "Attention", "AttentionState",
-    "get_attn_backend", "get_flash_attn_version"
+    "Attention",
+    "AttentionBackend",
+    "AttentionMetadata",
+    "AttentionType",
+    "AttentionMetadataBuilder",
+    "Attention",
+    "AttentionState",
+    "get_attn_backend",
 ]
@@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):

 class AttentionLayer(Protocol):

+    _q_scale: torch.Tensor
     _k_scale: torch.Tensor
     _v_scale: torch.Tensor
     _k_scale_float: float
@@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 # yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
-    compute_slot_mapping_start_idx, get_flash_attn_version,
-    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
-    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
-    is_block_tables_empty)
+    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
+    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
+    is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -630,9 +630,11 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window - 1,
                                 0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
+        self.vllm_flash_attn_version = get_flash_attn_version()
+        if (is_quantized_kv_cache(self.kv_cache_dtype)
+                and self.vllm_flash_attn_version != 3):
             raise NotImplementedError(
-                "FlashAttention with FP8 KV cache not yet supported")
+                "Only FlashAttention3 supports FP8 KV cache")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -647,7 +649,6 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
-        self.vllm_flash_attn_version = get_flash_attn_version()

     def forward(
         self,
@@ -671,13 +672,19 @@ class FlashAttentionImpl(AttentionImpl):
                 for profiling run.
             attn_metadata: Metadata for attention.
             NOTE: It in-place updates the output tensor.
+            NOTE: FP8 quantization, flash-attn expect the size of
+                  {q,k,v}_descale to be (num_sequences, num_kv_heads).
+                  We use torch's .expand() to avoid duplicating values
         """
-        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
-            "key/v_scale is not supported in FlashAttention.")
-
         assert output is not None, "Output tensor must be provided."

+        # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
+        if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
+            assert (
+                layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
+                    "key/v_scale is only supported in FlashAttention 3 with "
+                    "base dtype bfloat16")
+
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -694,6 +701,7 @@ class FlashAttentionImpl(AttentionImpl):
         window_size = self.sliding_window
         alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
         logits_soft_cap: Optional[float] = self.logits_soft_cap
+        fp8_attention = kv_cache_dtype.startswith("fp8")

         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
@@ -729,6 +737,19 @@ class FlashAttentionImpl(AttentionImpl):
                 layer._v_scale,
             )

+            if fp8_attention:
+                kv_cache = kv_cache.view(torch.float8_e4m3fn)
+                key_cache = key_cache.view(torch.float8_e4m3fn)
+                value_cache = value_cache.view(torch.float8_e4m3fn)
+
+        if fp8_attention:
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         (num_prefill_query_tokens, num_prefill_kv_tokens,
          num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
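Two things happen above when fp8 attention is active: the paged KV cache buffer is reinterpreted as float8_e4m3fn with .view() (a zero-copy dtype reinterpretation), and the query is flattened to 2-D before quantization, then reshaped back to (num_tokens, num_heads, head_size) for the kernel. A rough torch-only sketch of the same pattern, with a placeholder quant function standing in for ops.scaled_fp8_quant:

import torch

def placeholder_scaled_fp8_quant(x_2d, scale):
    # Stand-in for ops.scaled_fp8_quant: per-tensor scale, fp8 output.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return (x_2d / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn), scale

# Reinterpret a byte-backed cache buffer as fp8 without copying.
kv_cache_bytes = torch.zeros(2, 16, 4, 128, dtype=torch.uint8)
kv_cache_fp8 = kv_cache_bytes.view(torch.float8_e4m3fn)

# Flatten heads, quantize, restore the 3-D layout the kernel expects.
query = torch.randn(32, 8, 128)           # (num_tokens, num_heads, head_size)
num_tokens, num_heads, head_size = query.shape
q2d = query.reshape(num_tokens, num_heads * head_size).contiguous()
q_fp8, _ = placeholder_scaled_fp8_quant(q2d, torch.tensor(1.0))
query_fp8 = q_fp8.reshape(num_tokens, num_heads, head_size)
print(query_fp8.shape, query_fp8.dtype)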
@@ -753,6 +774,23 @@ class FlashAttentionImpl(AttentionImpl):
                 key = key[:num_prefill_kv_tokens]
                 value = value[:num_prefill_kv_tokens]

+                if fp8_attention:
+                    num_kv_tokens, num_kv_heads, head_size = key.shape
+
+                    key, _ = ops.scaled_fp8_quant(
+                        key.reshape((num_kv_tokens,
+                                     num_kv_heads * head_size)).contiguous(),
+                        layer._k_scale)
+                    key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+                    value, _ = ops.scaled_fp8_quant(
+                        value.reshape((num_kv_tokens,
+                                       num_kv_heads * head_size)).contiguous(),
+                        layer._v_scale)
+                    value = value.reshape(
+                        (num_kv_tokens, num_kv_heads, head_size))
+
+                descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
                 flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -768,13 +806,19 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=prefill_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
             else:
                 # prefix-enabled attention
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support prefix caching")
                 assert prefill_meta.seq_lens is not None
+                assert prefill_meta.query_start_loc is not None
                 max_seq_len = max(prefill_meta.seq_lens)
+                descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
+                                 key.shape[1])
                 flash_attn_varlen_func(  # noqa
                     q=query,
                     k=key_cache,
@@ -791,6 +835,9 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=prefill_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -804,6 +851,9 @@ class FlashAttentionImpl(AttentionImpl):
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support max_decode_query_len > 1"
                 )
+                assert decode_meta.query_start_loc is not None
+                descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
+                                 key.shape[1])
                 flash_attn_varlen_func(
                     q=decode_query,
                     k=key_cache,
@@ -820,6 +870,9 @@ class FlashAttentionImpl(AttentionImpl):
                     block_table=decode_meta.block_tables,
                     out=decode_output,
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -828,6 +881,7 @@ class FlashAttentionImpl(AttentionImpl):
                     _,
                     block_tables_arg,
                 ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
+                descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
                 flash_attn_with_kvcache(
                     q=decode_query.unsqueeze(1),
                     k_cache=key_cache,
@@ -841,6 +895,9 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
                     fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
                 )
         return output

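Each call site above derives descale_shape as (number of sequences in the batch, number of KV heads) and expands the layer's scalar scale tensors to that shape, matching the docstring note that flash-attn expects {q,k,v}_descale of size (num_sequences, num_kv_heads). A small sketch of why .expand() is cheap here (shapes are illustrative):

import torch

num_seqs, num_kv_heads = 4, 8     # illustrative values
q_scale = torch.tensor(0.02)      # scalar tensor, like layer._q_scale

q_descale = q_scale.expand(num_seqs, num_kv_heads)
print(q_descale.shape)            # torch.Size([4, 8])
print(q_descale.stride())         # (0, 0): a broadcast view, no data is copied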
@@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionState, MLAAttentionImpl)
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
-                                           get_flash_attn_version,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
 import numpy as np
 import torch

-from vllm import envs
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 logger = init_logger(__name__)
@@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens(

     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
-
-
-def get_flash_attn_version():
-    try:
-        from vllm.vllm_flash_attn.flash_attn_interface import (
-            fa_version_unsupported_reason, is_fa_version_supported)
-
-        # if hopper default to FA3, otherwise stick to FA2 for now
-        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
-        # use FA3 as default for both
-        if current_platform.get_device_capability()[0] == 9:
-            fa_version = 3 if is_fa_version_supported(3) else 2
-        else:
-            fa_version = 2
-
-        if envs.VLLM_FLASH_ATTN_VERSION is not None:
-            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
-            fa_version = envs.VLLM_FLASH_ATTN_VERSION
-            if (current_platform.get_device_capability()[0] == 10
-                    and envs.VLLM_FLASH_ATTN_VERSION == 3):
-                logger.warning("Cannot use FA version 3 on Blackwell platform",
-                               "defaulting to FA version 2.")
-                fa_version = 2
-
-        if not is_fa_version_supported(fa_version):
-            logger.error("Cannot use FA version %d is not supported due to %s",
-                         fa_version, fa_version_unsupported_reason(fa_version))
-
-        assert is_fa_version_supported(fa_version)
-        return fa_version
-    except (ImportError, AssertionError):
-        return None
@@ -84,6 +84,9 @@ class Attention(nn.Module):
         self.calculate_kv_scales = calculate_kv_scales
         self._k_scale = torch.tensor(1.0, dtype=torch.float32)
         self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+        # FlashAttn doesn't support quantizing the kv-cache only
+        # but requires q to be quantized as well.
+        self._q_scale = torch.tensor(1.0, dtype=torch.float32)

         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)
@@ -153,6 +156,7 @@ class Attention(nn.Module):
             ).parallel_config.pipeline_parallel_size)
         ]

+        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

@@ -178,7 +182,7 @@ class Attention(nn.Module):
         if self.calculate_kv_scales:
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(key, value)
+                self.calc_kv_scales(query, key, value)
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
@@ -225,7 +229,8 @@ class Attention(nn.Module):
         return torch.ops.vllm.unified_attention(
             query, key, value, self.layer_name)

-    def calc_kv_scales(self, key, value):
+    def calc_kv_scales(self, query, key, value):
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
         self._k_scale_float = self._k_scale.item()
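With calculate_kv_scales enabled, the scales above are derived dynamically from the observed activation range: the absolute maximum of each tensor divided by the corresponding *_SCALE_CONSTANT divisor (Q and K default to 200, V to 100 in this diff). A standalone sketch of the same arithmetic, not the vLLM API itself:

import torch

Q_SCALE_CONSTANT = 200   # defaults added to vllm/envs.py in this commit
K_SCALE_CONSTANT = 200
V_SCALE_CONSTANT = 100

def dynamic_scales(query, key, value):
    # Mirrors Attention.calc_kv_scales: abs-max divided by the range divisor.
    return (query.abs().max() / Q_SCALE_CONSTANT,
            key.abs().max() / K_SCALE_CONSTANT,
            value.abs().max() / V_SCALE_CONSTANT)

q, k, v = (torch.randn(16, 8, 128) for _ in range(3))
print(dynamic_scales(q, k, v))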
@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
+    Q_SCALE_CONSTANT: int = 200
     K_SCALE_CONSTANT: int = 200
     V_SCALE_CONSTANT: int = 100
     VLLM_SERVER_DEV_MODE: bool = False
@@ -524,13 +525,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Pad the fp8 weights to 256 bytes for ROCm
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),

+    # Divisor for dynamic query scale factor calculation for FP8 KV Cache
+    "Q_SCALE_CONSTANT":
+    lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
     # Divisor for dynamic key scale factor calculation for FP8 KV Cache
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),

     # Divisor for dynamic value scale factor calculation for FP8 KV Cache
     "V_SCALE_CONSTANT":
     lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),

     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
     lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
vllm/fa_utils.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+from vllm import envs
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def get_flash_attn_version() -> Optional[int]:
+    # import here to avoid circular dependencies
+    from vllm.platforms import current_platform
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)
+        device_capability = current_platform.get_device_capability()
+
+        assert device_capability is not None
+
+        # 1. default version depending on platform
+        fa_version = 3 if (device_capability.major == 9
+                           and is_fa_version_supported(3)) else 2
+
+        # 2. override if passed by environment
+        if envs.VLLM_FLASH_ATTN_VERSION is not None:
+            assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
+            fa_version = envs.VLLM_FLASH_ATTN_VERSION
+
+        # 3. fallback for unsupported combinations
+        if device_capability.major == 10 and fa_version == 3:
+            logger.warning("Cannot use FA version 3 on Blackwell platform",
+                           "defaulting to FA version 2.")
+            fa_version = 2
+
+        if not is_fa_version_supported(fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         fa_version, fa_version_unsupported_reason(fa_version))
+
+        assert is_fa_version_supported(fa_version)
+        return fa_version
+    except (ImportError, AssertionError):
+        return None
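For context, the new helper resolves to 3 only on Hopper-class devices (compute capability 9.x) with a working FA3 build, honours the VLLM_FLASH_ATTN_VERSION override, falls back to 2 on Blackwell, and returns None when the flash-attn interface cannot be imported. A hypothetical call site, for illustration only:

from vllm.fa_utils import get_flash_attn_version

fa_version = get_flash_attn_version()
if fa_version == 3:
    # Only the FA3 path accepts the q/k/v_descale arguments wired up in this diff.
    print("FlashAttention 3 available: fp8 attention path can be used")
elif fa_version == 2:
    print("FlashAttention 2 only: fp8 KV cache falls back to another backend")
else:
    print("vllm_flash_attn is not importable on this platform")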
@@ -26,11 +26,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):

     def create_weights(self, layer: torch.nn.Module):
         """
-        Create "weight" (aka k_scale and v_scale) for an attention layer.
+        Create "weight" (aka q_scale, k_scale and v_scale)
+        for an attention layer.
         """
-        # Initialize the KV cache scales to -1.0, which is an invalid value.
-        # If the k/v_scale appears in the checkpoint, it will be
+        # Initialize the Q and KV cache scales to -1.0, an invalid value.
+        # If the q and k/v_scales appear in the checkpoint, it will be
         # overwritten when loading weights.
+        layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                           requires_grad=False)
         layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
@@ -75,6 +78,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 raise ValueError("Only support per-tensor scaling factor "
                                  "for fp8 KV cache")

+            if layer.q_scale < 0.0:
+                logger.warning_once(
+                    "Checkpoint does not provide a q scaling factor. "
+                    "Setting it to k_scale. This only matters for "
+                    "the flash-attn backend.")
+                layer._q_scale.copy_(k_scale)
+
             # These are used in the final Attention.forward()
             layer._k_scale.copy_(k_scale)
             layer._v_scale.copy_(v_scale)
@@ -14,6 +14,7 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml

@@ -240,15 +241,6 @@ class CudaPlatformBase(Platform):
                         "Cannot use FlashAttention-2 backend for dtype other than "
                         "torch.float16 or torch.bfloat16.")
                     target_backend = _Backend.XFORMERS
-                elif kv_cache_dtype is not None and \
-                    kv_cache_dtype.startswith("fp8"):
-                    logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-                    logger.warning(
-                        "Please use FlashInfer backend with FP8 KV Cache for "
-                        "better performance by setting environment variable "
-                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
-                    target_backend = _Backend.XFORMERS
                 elif block_size % 16 != 0:
                     logger.info(
                         "Cannot use FlashAttention-2 backend for block size not "
@@ -270,6 +262,17 @@ class CudaPlatformBase(Platform):
                         "Cannot use FlashAttention-2 backend for head size %d.",
                         head_size)
                     target_backend = _Backend.XFORMERS
+                fp8_kv_cache = (kv_cache_dtype is not None
+                                and kv_cache_dtype.startswith("fp8"))
+                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
+                    )
+                    logger.warning(
+                        "Please use FlashInfer backend with FP8 KV Cache for "
+                        "better performance by setting environment variable "
+                        "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                    target_backend = _Backend.XFORMERS
             except ImportError:
                 logger.info(
                     "Cannot use FlashAttention-2 backend because the "
@@ -6,11 +6,12 @@ from typing import TYPE_CHECKING, Any, Optional
 import numpy as np
 import torch

+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
-from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
+        NOTE: FP8 quantization, flash-attn expect the size of
+              {q,k,v}_descale to be (num_sequences, num_kv_heads).
+              We use torch's .expand() to avoid duplicating values
         """
         assert output is not None, "Output tensor must be provided."

@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
             layer._k_scale,
             layer._v_scale,
         )
+        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
+                         key.shape[1])
+        if self.kv_cache_dtype.startswith("fp8"):
+            key_cache = key_cache.view(torch.float8_e4m3fn)
+            value_cache = value_cache.view(torch.float8_e4m3fn)
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))

         # Compute attention and update output up to `num_actual_tokens`.
         if not attn_metadata.use_cascade:
@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
                 fa_version=self.vllm_flash_attn_version,
+                q_descale=layer._q_scale.expand(descale_shape),
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
             )
             return output

@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
             fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale,
+            k_descale=layer._k_scale,
+            v_descale=layer._v_scale,
         )
         return output

@@ -391,6 +412,9 @@ def cascade_attention(
     block_table: torch.Tensor,
     common_prefix_len: int,
     fa_version: int,
+    q_descale: Optional[torch.Tensor] = None,
+    k_descale: Optional[torch.Tensor] = None,
+    v_descale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
     # TODO: Support sliding window.
@@ -402,6 +426,7 @@ def cascade_attention(
     assert common_prefix_len % block_size == 0
     num_common_kv_blocks = common_prefix_len // block_size
     assert num_common_kv_blocks > 0
+    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

     # Process shared prefix.
     prefix_output, prefix_lse = flash_attn_varlen_func(
@@ -419,8 +444,16 @@ def cascade_attention(
         softcap=logits_soft_cap,
         return_softmax_lse=True,
         fa_version=fa_version,
+        q_descale=q_descale.expand(descale_shape)
+        if q_descale is not None else None,
+        k_descale=k_descale.expand(descale_shape)
+        if k_descale is not None else None,
+        v_descale=v_descale.expand(descale_shape)
+        if v_descale is not None else None,
     )

+    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
+
     # Process suffix per query.
     suffix_output, suffix_lse = flash_attn_varlen_func(
         q=query,
@@ -437,6 +470,12 @@ def cascade_attention(
         softcap=logits_soft_cap,
         return_softmax_lse=True,
         fa_version=fa_version,
+        q_descale=q_descale.expand(descale_shape)
+        if q_descale is not None else None,
+        k_descale=k_descale.expand(descale_shape)
+        if k_descale is not None else None,
+        v_descale=v_descale.expand(descale_shape)
+        if v_descale is not None else None,
     )

     # Merge prefix and suffix outputs, and store the result in output.
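In cascade_attention the descale tensors are optional, so each one is expanded only when it is provided; the conditional expression repeated above can be factored into a small helper like the following sketch (hypothetical name, not part of the diff):

from typing import Optional

import torch

def maybe_expand(descale: Optional[torch.Tensor],
                 descale_shape: tuple) -> Optional[torch.Tensor]:
    # Same guard as `q_descale.expand(descale_shape) if q_descale is not None else None`
    return descale.expand(descale_shape) if descale is not None else None

descale_shape = (4, 8)                    # (num_query_sequences, num_kv_heads)
print(maybe_expand(torch.tensor(1.0), descale_shape).shape)
print(maybe_expand(None, descale_shape))  # stays None when no scale is given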
@@ -5,6 +5,7 @@ import pickle
 import signal
 import sys
 import time
+import traceback
 import weakref
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -370,6 +371,9 @@ class WorkerProc:
                 func = partial(cloudpickle.loads(method), self.worker)
                 output = func(*args, **kwargs)
             except Exception as e:
+                # Notes have been introduced in python 3.11
+                if hasattr(e, "add_note"):
+                    e.add_note(traceback.format_exc())
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.FAILURE, e))
                 logger.exception("WorkerProc hit an exception: %s", exc_info=e)
@@ -1558,7 +1558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     block_size=block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
-                    dtype=attn_module.dtype,
+                    dtype=self.kv_cache_dtype,
                     use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):