[V1] Enable V1 Fp8 cache for FA3 in the oracle (#15191)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
9c5c81b0da
commit
dccf535f8e
3
.gitignore
vendored
3
.gitignore
vendored
@ -2,7 +2,8 @@
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/fa_utils.py
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=self.alibi_slopes is not None)
|
||||
if (is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
and self.vllm_flash_attn_version != 3):
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) and (
|
||||
not self.kv_cache_dtype.startswith("fp8")
|
||||
or not flash_attn_supports_fp8()):
|
||||
raise NotImplementedError(
|
||||
"Only FlashAttention3 supports FP8 KV cache")
|
||||
f"FlashAttention does not support {self.kv_cache_dtype} "
|
||||
"kv-cache on this device "
|
||||
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = self.logits_soft_cap
|
||||
fp8_attention = kv_cache_dtype.startswith("fp8")
|
||||
|
||||
if fp8_attention and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support FP8 kv-cache on this device.")
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
|
@ -205,7 +205,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
@ -214,6 +213,7 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
||||
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
@ -1157,10 +1157,6 @@ class CacheConfig:
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"V1 does not yet support fp8 KV cache. "
|
||||
"Set VLLM_USE_V1=0 to enable fp8 kv cache.")
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
|
@ -1562,6 +1562,17 @@ class EngineArgs:
|
||||
|
||||
# No Fp8 KV cache so far.
|
||||
if self.kv_cache_dtype != "auto":
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
will_use_fa = (
|
||||
current_platform.is_cuda()
|
||||
and not envs.is_set("VLLM_ATTENTION_BACKEND")
|
||||
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
|
||||
supported = False
|
||||
if fp8_attention and will_use_fa:
|
||||
from vllm.vllm_flash_attn.fa_utils import (
|
||||
flash_attn_supports_fp8)
|
||||
supported = flash_attn_supports_fp8()
|
||||
if not supported:
|
||||
_raise_or_fallback(feature_name="--kv-cache-dtype",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import import_pynvml
|
||||
|
||||
@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
|
||||
try:
|
||||
import vllm.vllm_flash_attn # noqa: F401
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend)
|
||||
FlashAttentionBackend, flash_attn_supports_fp8)
|
||||
|
||||
supported_sizes = \
|
||||
FlashAttentionBackend.get_supported_head_sizes()
|
||||
@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
|
||||
target_backend = _Backend.XFORMERS
|
||||
fp8_kv_cache = (kv_cache_dtype is not None
|
||||
and kv_cache_dtype.startswith("fp8"))
|
||||
if (fp8_kv_cache and get_flash_attn_version() != 3):
|
||||
if (fp8_kv_cache and not flash_attn_supports_fp8()):
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend for FP8 KV cache."
|
||||
)
|
||||
"Cannot use FlashAttention backend for FP8 KV cache.")
|
||||
logger.warning(
|
||||
"Please use FlashInfer backend with FP8 KV Cache for "
|
||||
"better performance by setting environment variable "
|
||||
|
@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashAttention V1 with FP8 KV cache not yet supported")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support fp8 kv-cache on this device.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.fa_utils import get_flash_attn_version
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
|
||||
return fa_version
|
||||
except (ImportError, AssertionError):
|
||||
return None
|
||||
|
||||
|
||||
def flash_attn_supports_fp8() -> bool:
|
||||
from vllm.platforms import current_platform
|
||||
return get_flash_attn_version() == 3 and \
|
||||
current_platform.get_device_capability().major == 9
|
Loading…
x
Reference in New Issue
Block a user