import enum
from functools import lru_cache
from typing import Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip

logger = init_logger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()

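# The member names above are also the accepted values of the
# VLLM_ATTENTION_BACKEND environment variable: _which_attn_to_use() resolves
# that string via _Backend[...]. A hypothetical override (a sketch, not part
# of this module) would be to export VLLM_ATTENTION_BACKEND=XFORMERS in the
# environment before starting vLLM.

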
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Selects and returns the attention backend class for the given dtype.

    The result is cached by @lru_cache, so the selection logic (and its log
    messages) runs at most once per dtype in a process.
    """
    backend = _which_attn_to_use(dtype)
    if backend == _Backend.FLASH_ATTN:
        logger.info("Using FlashAttention-2 backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend
    elif backend == _Backend.ROCM_FLASH:
        logger.info("Using ROCmFlashAttention backend.")
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        return ROCmFlashAttentionBackend
    elif backend == _Backend.TORCH_SDPA:
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using FlashInfer backend.")
        logger.warning("Eager mode is enforced for the FlashInfer backend.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError("Invalid attention backend.")


def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
    """Returns which attention backend to use."""
    if is_cpu():
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # Not Instinct-series GPUs.
            logger.info("flash_attn is not supported on NAVI GPUs.")
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
            "package is not found. `pip install vllm-flash-attn` for better "
            "performance.")
        return _Backend.XFORMERS

    # An explicit VLLM_ATTENTION_BACKEND setting (the name of a _Backend
    # member, e.g. "XFORMERS") overrides the default choice.
    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]

    # Default case.
    return _Backend.FLASH_ATTN
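

# --- Usage sketch (illustrative; not part of the selection logic above) ---
# Callers resolve the backend class once per dtype; because get_attn_backend
# is wrapped in @lru_cache, repeated calls with the same dtype return the
# cached class without re-running the checks above:
#
#     backend_cls = get_attn_backend(torch.float16)
#     impl_cls = backend_cls.get_impl_cls()  # concrete implementation class
#
# get_impl_cls() is assumed here to be part of the AttentionBackend interface
# from vllm.attention.backends.abstract; the constructor arguments of the
# implementation class are backend-specific and omitted.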