[FIX] Make flash_attn optional (#3269)
parent 99c3cfb83c
commit 1cb0cc2975

.gitignore (vendored) | 3
@@ -184,6 +184,3 @@ _build/
 
 # Benchmark dataset
 *.json
-
-# Third-party Python packages.
-vllm/thirdparty_files/

setup.py | 48
@@ -3,7 +3,6 @@ import io
 import os
 import re
 import subprocess
-import sys
 import warnings
 from pathlib import Path
 from typing import List, Set
@@ -15,8 +14,6 @@ import torch.utils.cpp_extension as torch_cpp_ext
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
-# This is a temporary directory to store third-party packages.
-THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
 
 # If you are developing the C++ backend of vLLM, consider building vLLM with
 # `python setup.py develop` since it will give you incremental builds.
@@ -327,46 +324,8 @@ if _is_cuda():
                 "nvcc": NVCC_FLAGS_PUNICA,
             },
         ))
-
-    # Download the FlashAttention package.
-    # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
-    flash_attn_version = "2.5.6"
-    install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
-    subprocess.check_call(
-        [
-            sys.executable,
-            "-m",
-            "pip",
-            "install",
-            "-q",
-            f"--target={install_dir}",
-            "einops",  # Dependency of flash-attn.
-            f"flash-attn=={flash_attn_version}",
-            "--no-dependencies",  # Required to avoid re-installing torch.
-        ],
-        env=dict(os.environ, CC="gcc"),
-    )
-
-    # Copy the FlashAttention package into the vLLM package after build.
-    class build_ext(BuildExtension):
-
-        def run(self):
-            super().run()
-            target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
-            if not os.path.exists(target_dir):
-                os.makedirs(target_dir)
-            self.copy_tree(install_dir, target_dir)
-
-    class BinaryDistribution(setuptools.Distribution):
-
-        def has_ext_modules(self):
-            return True
-
-else:
-    build_ext = BuildExtension
-    BinaryDistribution = setuptools.Distribution
-    if _is_neuron():
-        neuronxcc_version = get_neuronxcc_version()
+elif _is_neuron():
+    neuronxcc_version = get_neuronxcc_version()
 
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
@@ -509,7 +468,6 @@ setuptools.setup(
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
-    distclass=BinaryDistribution,
+    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
    package_data=package_data,
 )
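
With the flash-attn download and the custom build_ext/BinaryDistribution removed, the CUDA build path falls back to the stock torch BuildExtension. A minimal, self-contained sketch of the packaging pattern this leaves behind (placeholder project values, not vLLM's real setup.py):

    # Sketch only: placeholder name/version; extensions omitted. The point is
    # that no third-party wheel is pip-installed into the source tree and no
    # custom build_ext copies files into build_lib anymore.
    import setuptools
    from torch.utils.cpp_extension import BuildExtension

    setuptools.setup(
        name="example_pkg",      # placeholder
        version="0.0.1",         # placeholder
        packages=setuptools.find_packages(),
        ext_modules=[],          # CUDA extensions left out of this sketch
        cmdclass={"build_ext": BuildExtension},
    )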

@@ -1,28 +1,12 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 
-
-# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
-def _configure_system():
-    import os
-    import sys
-
-    # Importing flash-attn.
-    thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
-                                    "thirdparty_files")
-    sys.path.insert(0, thirdparty_files)
-
-
-_configure_system()
-# Delete configuration function.
-del _configure_system
-
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
-from vllm.engine.async_llm_engine import AsyncLLMEngine  # noqa: E402
-from vllm.engine.llm_engine import LLMEngine  # noqa: E402
-from vllm.engine.ray_utils import initialize_cluster  # noqa: E402
-from vllm.entrypoints.llm import LLM  # noqa: E402
-from vllm.outputs import CompletionOutput, RequestOutput  # noqa: E402
-from vllm.sampling_params import SamplingParams  # noqa: E402
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
 
 __version__ = "0.3.3"
 
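
For context, a hedged usage sketch of the public API re-exported above; the model name and sampling settings are arbitrary placeholders, and a working vLLM installation on a CUDA machine is assumed:

    # Hedged sketch: any small HF causal LM works; values are placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.8, max_tokens=32)
    for out in llm.generate(["Hello, my name is"], params):
        print(out.outputs[0].text)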

@@ -1,12 +1,16 @@
 """Attention layer."""
+from functools import lru_cache
 from typing import List, Optional
 
 import torch
 import torch.nn as nn
 
+from vllm.logger import init_logger
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.utils import is_hip
 
+logger = init_logger(__name__)
+
 
 class Attention(nn.Module):
     """Attention layer.
@@ -30,17 +34,12 @@ class Attention(nn.Module):
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
-        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
-                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
-            # Ampere or later NVIDIA GPUs.
-            # NOTE(woosuk): FlashAttention does not support FP32.
+        if _use_flash_attn():
             from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
             self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                  num_kv_heads, alibi_slopes,
                                                  sliding_window)
         else:
-            # Turing and Volta NVIDIA GPUs or AMD GPUs.
-            # Or FP32 on any GPU.
             from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
             self.backend = XFormersBackend(num_heads, head_size, scale,
                                            num_kv_heads, alibi_slopes,
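
A hedged sketch of how this constructor-time backend selection is exercised; it assumes the class is importable as vllm.model_executor.layers.attention.Attention (the package the backend imports above live in), a CUDA device, and made-up layer parameters:

    # Hedged sketch: parameters are placeholders; a CUDA GPU is assumed because
    # the backend check queries torch.cuda.get_device_capability().
    import torch
    from vllm.model_executor.layers.attention import Attention

    torch.set_default_dtype(torch.float16)  # FP32 would force the xformers path
    attn = Attention(num_heads=32, head_size=128, scale=128 ** -0.5)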

@@ -57,3 +56,29 @@
     ) -> torch.Tensor:
         return self.backend.forward(query, key, value, key_cache, value_cache,
                                     input_metadata)
+
+
+@lru_cache(maxsize=1)
+def _use_flash_attn() -> bool:
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info("flash_attn is not found. Using xformers backend.")
+        return False
+
+    if is_hip():
+        # AMD GPUs.
+        return False
+    if torch.cuda.get_device_capability()[0] < 8:
+        # Volta and Turing NVIDIA GPUs.
+        logger.info("flash_attn is not supported on Turing or older GPUs. "
+                    "Using xformers backend.")
+        return False
+    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
+        logger.info(
+            "flash_attn only supports torch.float16 or torch.bfloat16. "
+            "Using xformers backend.")
+        return False
+
+    logger.info("Using flash_attn backend.")
+    return True
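
The helper above is essentially a cached optional-import probe; a standalone sketch of the same pattern, with a hypothetical helper name and no vLLM dependency:

    # Standalone sketch (hypothetical name): probe an optional package once,
    # cache the result via lru_cache, and let callers branch on it.
    from functools import lru_cache

    @lru_cache(maxsize=1)
    def _optional_import_ok(module_name: str = "flash_attn") -> bool:
        try:
            __import__(module_name)
        except ImportError:
            return False
        return True

    print(_optional_import_ok())  # computed once; later calls hit the cache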

@@ -1,7 +1,6 @@
 """Attention layer with Flash and PagedAttention."""
 from typing import List, Optional
 
-# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
 from flash_attn import flash_attn_func
 import torch
 
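
With the vendored copy gone, this import only resolves when flash-attn is installed like any other package; a hedged sketch of calling flash_attn_func directly under that assumption (shapes follow flash-attn's batch, seqlen, heads, head_dim convention; fp16 on a CUDA device):

    # Hedged sketch: requires a CUDA GPU and flash-attn installed normally.
    import torch
    from flash_attn import flash_attn_func

    q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    out = flash_attn_func(q, k, v, causal=True)  # same shape as q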