[V1] Enable V1 Fp8 cache for FA3 in the oracle (#15191)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent 9c5c81b0da
commit dccf535f8e
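
With this change, a V1 engine can keep its KV cache in fp8 whenever FlashAttention 3 is usable on the device, instead of being forced off V1. A minimal usage sketch (assumptions: the model name is only an example, and a Hopper-class GPU with an FA3 build is present):

from vllm import LLM, SamplingParams

# Request an fp8 KV cache; with this commit the V1 oracle accepts it when FA3
# on this device supports fp8, rather than falling back unconditionally.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # example model only
          kv_cache_dtype="fp8")
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)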
.gitignore (vendored): 3 changed lines

@@ -2,7 +2,8 @@
 /vllm/_version.py
 
 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -22,12 +22,13 @@ from vllm.attention.backends.utils import (
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ class FlashAttentionImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
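
The constructor guard above now accepts any fp8 cache dtype that the installed FlashAttention build can serve, and rejects everything else that is quantized. A minimal pre-flight helper in the same spirit (assumption: this standalone function only mirrors the guard and is not part of the vLLM API):

from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8

def fp8_kv_cache_ok(kv_cache_dtype: str) -> bool:
    # "auto" means an unquantized cache, which FlashAttention always handles.
    if kv_cache_dtype == "auto":
        return True
    # fp8, fp8_e4m3 and fp8_e5m2 all start with "fp8"; any other quantized
    # dtype is rejected, matching the guard in FlashAttentionImpl above.
    return kv_cache_dtype.startswith("fp8") and flash_attn_supports_fp8()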
@@ -704,6 +708,10 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
@@ -205,7 +205,6 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1157,10 +1157,6 @@ class CacheConfig:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
@@ -1562,9 +1562,20 @@ class EngineArgs:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            _raise_or_fallback(feature_name="--kv-cache-dtype",
-                               recommend_to_remove=False)
-            return False
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
+                _raise_or_fallback(feature_name="--kv-cache-dtype",
+                                   recommend_to_remove=False)
+                return False
 
         # No Prompt Adapter so far.
         if self.enable_prompt_adapter:
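
The oracle no longer vetoes every non-auto cache dtype: it first asks whether FlashAttention will be the backend and, if so, whether that build supports fp8. A condensed sketch of the decision (assumption: written as a standalone function for illustration; the real logic is inline in EngineArgs as shown above):

import vllm.envs as envs
from vllm.platforms import current_platform

def v1_supports_kv_cache_dtype(kv_cache_dtype: str) -> bool:
    if kv_cache_dtype == "auto":
        return True
    fp8_attention = kv_cache_dtype.startswith("fp8")
    # FlashAttention is picked either implicitly (CUDA with no backend
    # override) or explicitly via VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.
    will_use_fa = (current_platform.is_cuda()
                   and not envs.is_set("VLLM_ATTENTION_BACKEND")
                   ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
    if fp8_attention and will_use_fa:
        from vllm.vllm_flash_attn.fa_utils import flash_attn_supports_fp8
        return flash_attn_supports_fp8()
    return False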
@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
 
@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
         try:
             import vllm.vllm_flash_attn  # noqa: F401
             from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
+                FlashAttentionBackend, flash_attn_supports_fp8)
 
             supported_sizes = \
                 FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
                 target_backend = _Backend.XFORMERS
             fp8_kv_cache = (kv_cache_dtype is not None
                             and kv_cache_dtype.startswith("fp8"))
-            if (fp8_kv_cache and get_flash_attn_version() != 3):
+            if (fp8_kv_cache and not flash_attn_supports_fp8()):
                 logger.info(
-                    "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                )
+                    "Cannot use FlashAttention backend for FP8 KV cache.")
                 logger.warning(
                     "Please use FlashInfer backend with FP8 KV Cache for "
                     "better performance by setting environment variable "
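
For FP8 KV cache on devices where FlashAttention cannot serve it, the log above points users at FlashInfer. A one-line sketch (assumption: "FLASHINFER" is the value accepted by VLLM_ATTENTION_BACKEND; it must be set before the engine builds its attention backend):

import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # select FlashInfer explicitly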
@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
                 "are not implemented for "
                 "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")
 
     def forward(
         self,
@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
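
The new helper ties FP8 support to two facts: the resolved FlashAttention version is 3 and the GPU reports compute capability 9.x (Hopper). A quick probe using only the functions added here (assumes a CUDA build of vLLM with vllm_flash_attn available):

from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

print("FlashAttention version:", get_flash_attn_version())   # None if unavailable
print("FP8 KV cache supported:", flash_attn_supports_fp8())  # True only for FA3 on SM90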