[Misc] Use vllm-flash-attn instead of flash-attn (#4686)
This commit is contained in:
parent
230c4b38c1
commit
89579a201f
21
Dockerfile
21
Dockerfile
@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip cache remove vllm_nccl*
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
FROM dev as flash-attn-builder
|
||||
# max jobs used for build
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# flash attention version
|
||||
ARG flash_attn_version=v2.5.8
|
||||
ENV FLASH_ATTN_VERSION=${flash_attn_version}
|
||||
|
||||
WORKDIR /usr/src/flash-attention-v2
|
||||
|
||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
||||
--no-build-isolation --no-deps --no-cache-dir
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
|
||||
@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install dist/*.whl --verbose
|
||||
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
|
||||
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
|
||||
torch == 2.3.0
|
||||
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
|
||||
vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
|
||||
|
14
setup.py
14
setup.py
@ -355,14 +355,18 @@ def get_requirements() -> List[str]:
|
||||
|
||||
if _is_cuda():
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
cuda_major = torch.version.cuda.split(".")[0]
|
||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||
modified_requirements = []
|
||||
for req in requirements:
|
||||
if "vllm-nccl-cu12" in req:
|
||||
modified_requirements.append(
|
||||
req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}"))
|
||||
else:
|
||||
modified_requirements.append(req)
|
||||
req = req.replace("vllm-nccl-cu12",
|
||||
f"vllm-nccl-cu{cuda_major}")
|
||||
elif ("vllm-flash-attn" in req
|
||||
and not (cuda_major == "12" and cuda_minor == "1")):
|
||||
# vllm-flash-attn is built only for CUDA 12.1.
|
||||
# Skip for other versions.
|
||||
continue
|
||||
modified_requirements.append(req)
|
||||
requirements = modified_requirements
|
||||
elif _is_hip():
|
||||
requirements = _read_requirements("requirements-rocm.txt")
|
||||
|
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
|
@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
from vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
|
@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
||||
return _Backend.XFORMERS
|
||||
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
import vllm_flash_attn # noqa: F401
|
||||
except ImportError:
|
||||
logger.info(
|
||||
"Cannot use FlashAttention-2 backend because the flash_attn "
|
||||
"package is not found. Please install it for better performance.")
|
||||
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
|
||||
"package is not found. `pip install vllm-flash-attn` for better "
|
||||
"performance.")
|
||||
return _Backend.XFORMERS
|
||||
|
||||
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
|
||||
|
Loading…
x
Reference in New Issue
Block a user