[Misc] Use vllm-flash-attn instead of flash-attn (#4686)

This commit is contained in:
Woosuk Kwon 2024-05-08 13:15:34 -07:00 committed by GitHub
parent 230c4b38c1
commit 89579a201f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 16 additions and 31 deletions

View File

@@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl*
#################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
FROM dev as flash-attn-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.8
ENV FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR /usr/src/flash-attention-v2
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
--no-build-isolation --no-deps --no-cache-dir
#################### FLASH_ATTENTION Build IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
@@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################

View File

@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0

View File

@@ -355,14 +355,18 @@ def get_requirements() -> List[str]:
if _is_cuda():
requirements = _read_requirements("requirements-cuda.txt")
cuda_major = torch.version.cuda.split(".")[0]
cuda_major, cuda_minor = torch.version.cuda.split(".")
modified_requirements = []
for req in requirements:
if "vllm-nccl-cu12" in req:
modified_requirements.append(
req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}"))
else:
modified_requirements.append(req)
req = req.replace("vllm-nccl-cu12",
f"vllm-nccl-cu{cuda_major}")
elif ("vllm-flash-attn" in req
and not (cuda_major == "12" and cuda_minor == "1")):
# vllm-flash-attn is built only for CUDA 12.1.
# Skip for other versions.
continue
modified_requirements.append(req)
requirements = modified_requirements
elif _is_hip():
requirements = _read_requirements("requirements-rocm.txt")

View File

@@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
import torch
from flash_attn import flash_attn_varlen_func
from vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,

View File

@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type
import flashinfer
import torch
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

View File

@@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
return _Backend.XFORMERS
try:
import flash_attn # noqa: F401
import vllm_flash_attn # noqa: F401
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the flash_attn "
"package is not found. Please install it for better performance.")
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
"package is not found. `pip install vllm-flash-attn` for better "
"performance.")
return _Backend.XFORMERS
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND