[Bugfix][Kernel] Add head size check for attention backend selection (#4944)
parent 14772eeb8e · commit 99eff67ba9
@@ -9,11 +9,13 @@ from vllm._C import cache_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
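The module-level _SUPPORTED_HEAD_SIZES constant becomes a static method so the list can be queried from outside the module. As a rough illustration (not part of this commit; it only assumes the vllm.attention.backends.flash_attn import path that appears in the selector hunk further down), a caller could check a model's head size like this:

from vllm.attention.backends.flash_attn import FlashAttentionBackend

# Hypothetical head size, e.g. hidden_size // num_attention_heads of some model.
head_size = 80

supported = FlashAttentionBackend.get_supported_head_sizes()
if head_size in supported:
    print(f"FlashAttention can handle head size {head_size}")
else:
    print(f"Head size {head_size} unsupported; supported sizes: {supported}")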
@@ -237,10 +239,12 @@ class FlashAttentionImpl(AttentionImpl):
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
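With this change, an unsupported head size is rejected with an explicit error at construction time instead of surfacing later in the kernel. A standalone sketch of the message shape (it only mirrors the check above for illustration; it does not call vLLM's real constructor):

from typing import List

def check_head_size(head_size: int, support_head_sizes: List[int]) -> None:
    # Illustrative stand-in for the check added to FlashAttentionImpl.__init__.
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}.")

check_head_size(128, [32, 64, 96, 128, 160, 192, 224, 256])  # passes
try:
    check_head_size(80, [32, 64, 96, 128, 160, 192, 224, 256])
except ValueError as e:
    print(e)  # Head size 80 is not supported by FlashAttention. ...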
@@ -34,11 +34,21 @@ def get_attn_backend(
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
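The net effect in the selector: when FlashAttention was picked but the model's head size is not in the supported list, vLLM now logs a message and falls back to XFormers instead of failing later. A simplified, self-contained sketch of that decision (the standalone function and string return values are illustrative stand-ins, not vLLM's real API):

import logging
from typing import List

logger = logging.getLogger("attn_selector_sketch")

FLASH_ATTN_HEAD_SIZES: List[int] = [32, 64, 96, 128, 160, 192, 224, 256]

def choose_backend(head_size: int) -> str:
    # Mirrors the fallback added above: prefer FlashAttention-2 when the
    # head size is supported, otherwise fall back to XFormers.
    if head_size in FLASH_ATTN_HEAD_SIZES:
        logger.info("Using FlashAttention-2 backend.")
        return "flash-attn"
    logger.info(
        "Cannot use FlashAttention-2 backend for head size %d. "
        "Using XFormers backend instead.", head_size)
    return "xformers"

print(choose_backend(128))  # flash-attn
print(choose_backend(80))   # xformers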