From dccf535f8edb861e95cf2f0b3512e1fd737265c2 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sun, 23 Mar 2025 18:07:04 -0400
Subject: [PATCH] [V1] Enable V1 Fp8 cache for FA3 in the oracle (#15191)

Signed-off-by: Lucas Wilkinson
Signed-off-by: Lucas Wilkinson
---
 .gitignore                               |  3 ++-
 vllm/attention/backends/flash_attn.py    | 16 ++++++++++++----
 vllm/attention/backends/mla/common.py    |  2 +-
 vllm/config.py                           |  4 ----
 vllm/engine/arg_utils.py                 | 17 ++++++++++++++---
 vllm/platforms/cuda.py                   |  8 +++-----
 vllm/v1/attention/backends/flash_attn.py | 10 ++++++----
 vllm/v1/attention/backends/mla/common.py |  2 +-
 vllm/{ => vllm_flash_attn}/fa_utils.py   |  6 ++++++
 9 files changed, 45 insertions(+), 23 deletions(-)
 rename vllm/{ => vllm_flash_attn}/fa_utils.py (90%)

diff --git a/.gitignore b/.gitignore
index e40752f4..6f5cbd07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,8 @@
 /vllm/_version.py
 
 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 4cb0b916..27bd292b 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 258090d3..1b1ab314 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -205,7 +205,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
diff --git a/vllm/config.py b/vllm/config.py
index ea056bcc..e486889b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1157,10 +1157,6 @@ class CacheConfig:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c9946221..38a47a84 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1562,9 +1562,20 @@ class EngineArgs:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            _raise_or_fallback(feature_name="--kv-cache-dtype",
-                               recommend_to_remove=False)
-            return False
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
+                _raise_or_fallback(feature_name="--kv-cache-dtype",
+                                   recommend_to_remove=False)
+                return False
 
         # No Prompt Adapter so far.
         if self.enable_prompt_adapter:
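
The arg_utils.py hunk above is the "oracle" change named in the subject: a non-"auto" --kv-cache-dtype no longer always falls back off V1; fp8 caches stay on V1 when the vLLM FlashAttention backend will be used and it reports fp8 support. The standalone sketch below restates that decision outside EngineArgs. It is illustrative only: kv_cache_dtype_supported_on_v1, is_cuda, and attention_backend_env are made-up names standing in for the EngineArgs fields and env lookups, while flash_attn_supports_fp8 is the real helper added at the end of this patch.

# Illustrative sketch (not part of the patch) of the V1 oracle decision above.
# Assumes vLLM is installed so vllm.vllm_flash_attn.fa_utils is importable.
from typing import Optional


def kv_cache_dtype_supported_on_v1(kv_cache_dtype: str, is_cuda: bool,
                                   attention_backend_env: Optional[str]) -> bool:
    """Return True when a non-'auto' kv-cache dtype can stay on the V1 engine."""
    from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8

    fp8_attention = kv_cache_dtype.startswith("fp8")
    # FlashAttention is picked either implicitly (CUDA, no backend override)
    # or explicitly via VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.
    will_use_fa = (is_cuda and attention_backend_env is None) \
        or attention_backend_env == "FLASH_ATTN_VLLM_V1"
    return fp8_attention and will_use_fa and flash_attn_supports_fp8()
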
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 38d8fffd..bb773180 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
 
@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
             try:
                 import vllm.vllm_flash_attn  # noqa: F401
                 from vllm.attention.backends.flash_attn import (  # noqa: F401
-                    FlashAttentionBackend)
+                    FlashAttentionBackend, flash_attn_supports_fp8)
 
                 supported_sizes = \
                     FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
                     target_backend = _Backend.XFORMERS
                 fp8_kv_cache = (kv_cache_dtype is not None
                                 and kv_cache_dtype.startswith("fp8"))
-                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                     logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                    )
+                        "Cannot use FlashAttention backend for FP8 KV cache.")
                     logger.warning(
                         "Please use FlashInfer backend with FP8 KV Cache for "
                         "better performance by setting environment variable "
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 27b3aabb..92e4ffd0 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")
 
     def forward(
         self,
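
Combined with the oracle change, the V1 flash_attn.py hunks above replace the blanket "FP8 not yet supported" error with a capability check, so an fp8 KV cache no longer forces a fallback off V1 when FlashAttention 3 can serve it. A hedged end-to-end sketch of what this enables follows; the model name is a placeholder, and a Hopper-class GPU plus a vLLM build whose vllm_flash_attn is FA3 are assumed.

# End-to-end sketch (illustrative): V1 engine with an fp8 KV cache.
# Assumes an SM90 (Hopper) GPU and a vLLM build that ships FlashAttention 3.
import os

os.environ["VLLM_USE_V1"] = "1"  # make the V1 engine explicit

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
          kv_cache_dtype="fp8")  # previously rejected by the V1 oracle
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
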
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 31244443..1437db7e 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
diff --git a/vllm/fa_utils.py b/vllm/vllm_flash_attn/fa_utils.py
similarity index 90%
rename from vllm/fa_utils.py
rename to vllm/vllm_flash_attn/fa_utils.py
index 41765349..ca88549f 100644
--- a/vllm/fa_utils.py
+++ b/vllm/vllm_flash_attn/fa_utils.py
@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
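
The relocated fa_utils.py (kept importable by the .gitignore carve-out at the top of the patch) now also exposes flash_attn_supports_fp8(), which is True only when FA3 is in use on a compute-capability 9.x (Hopper) device. Below is a small usage sketch; resolve_kv_cache_dtype and its fall-back-to-"auto" policy are illustrative, not vLLM behaviour.

# Illustrative use of the new helper (requires a CUDA build of vLLM that
# ships vllm_flash_attn). The fallback policy below is an example only.
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)


def resolve_kv_cache_dtype(requested: str) -> str:
    """Keep an fp8 kv-cache request only when FlashAttention can honour it."""
    if requested.startswith("fp8") and not flash_attn_supports_fp8():
        # e.g. FA2, or FA3 on a pre-Hopper (compute capability < 9.0) GPU
        return "auto"
    return requested


print("FA version:", get_flash_attn_version())  # 2, 3, or None per the patch
print("kv-cache dtype:", resolve_kv_cache_dtype("fp8_e4m3"))
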