[V1] Enable V1 Fp8 cache for FA3 in the oracle (#15191)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-03-23 18:07:04 -04:00 committed by GitHub
parent 9c5c81b0da
commit dccf535f8e
9 changed files with 45 additions and 23 deletions

.gitignore
View File

@@ -2,7 +2,8 @@
 /vllm/_version.py

 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py

 # Byte-compiled / optimized / DLL files
 __pycache__/

View File

@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = self.logits_soft_cap
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")

         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
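
The constructor and forward() checks above boil down to one rule: a quantized kv-cache is only accepted when it is an fp8 variant and the installed FlashAttention build reports FP8 support. A condensed restatement as a standalone helper (a sketch; the function name is hypothetical, the imports mirror the hunk above):

# Sketch: the fp8 kv-cache gate applied by the FlashAttention backend above.
from vllm.attention.backends.abstract import is_quantized_kv_cache
from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8

def _validate_kv_cache_dtype(kv_cache_dtype: str) -> None:
    # Unquantized ("auto") caches always pass; quantized ones must be fp8*
    # and the device must run FA3 with FP8 support (see fa_utils below).
    if is_quantized_kv_cache(kv_cache_dtype) and (
            not kv_cache_dtype.startswith("fp8")
            or not flash_attn_supports_fp8()):
        raise NotImplementedError(
            f"FlashAttention does not support {kv_cache_dtype} "
            "kv-cache on this device.")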

View File

@@ -205,7 +205,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

View File

@@ -1157,10 +1157,6 @@ class CacheConfig:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "

View File

@@ -1562,6 +1562,17 @@ class EngineArgs:
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
                 return False
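
The decision this hunk adds to the V1 oracle can be read as a small pure function (a sketch with a hypothetical name; the helpers are the ones used in the hunk above):

# Sketch: fp8 kv-cache keeps the V1 engine only when FlashAttention will be
# selected as the backend and it supports FP8 (i.e. FA3 on this device).
import vllm.envs as envs
from vllm.platforms import current_platform

def _kv_cache_dtype_ok_for_v1(kv_cache_dtype: str) -> bool:
    if kv_cache_dtype == "auto":
        return True  # unquantized cache, nothing to check here
    will_use_fa = (current_platform.is_cuda()
                   and not envs.is_set("VLLM_ATTENTION_BACKEND")
                   ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
    if kv_cache_dtype.startswith("fp8") and will_use_fa:
        from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8
        return flash_attn_supports_fp8()
    return False  # any other combination still falls back (or raises if V1 was forced)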

View File

@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
         try:
             import vllm.vllm_flash_attn  # noqa: F401
             from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
+                FlashAttentionBackend, flash_attn_supports_fp8)

             supported_sizes = \
                 FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
                 target_backend = _Backend.XFORMERS
             fp8_kv_cache = (kv_cache_dtype is not None
                             and kv_cache_dtype.startswith("fp8"))
-            if (fp8_kv_cache and get_flash_attn_version() != 3):
+            if (fp8_kv_cache and not flash_attn_supports_fp8()):
                 logger.info(
-                    "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                )
+                    "Cannot use FlashAttention backend for FP8 KV cache.")
                 logger.warning(
                     "Please use FlashInfer backend with FP8 KV Cache for "
                     "better performance by setting environment variable "

View File

@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")

     def forward(
         self,

View File

@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

View File

@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
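
Taken together, flash_attn_supports_fp8() is shorthand for "FA3 and compute capability 9.x (Hopper)". A quick probe sketch, assuming a CUDA build of vLLM with vllm_flash_attn importable:

# Sketch: report whether the fp8 kv-cache path through FlashAttention is
# available on the current device.
from vllm.platforms import current_platform
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

print("FA version:", get_flash_attn_version())            # 2, 3, or None
print("device capability:", current_platform.get_device_capability())
print("fp8 kv-cache via FlashAttention:", flash_attn_supports_fp8())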