[Kernel] Build flash-attn from source (#8245)

This commit is contained in:
Luka Govedič 2024-09-21 02:27:10 -04:00 committed by GitHub
parent 0faab90eb0
commit 71c60491f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 124 additions and 41 deletions

View File

@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt
export MAX_JOBS=1
# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist

5
.gitignore vendored
View File

@ -1,6 +1,9 @@
# vllm commit id, generated by setup.py
vllm/commit_id.py
# vllm-flash-attn built from source
vllm/vllm_flash_attn/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@ -12,6 +15,8 @@ __pycache__/
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/

View File

@ -1,5 +1,16 @@
cmake_minimum_required(VERSION 3.26)
# When building directly using CMake, make sure you run the install step
# (it places the .so files in the correct location).
#
# Example:
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
# cmake --build . --target install
#
# If you want to only build one target, make sure to install it manually:
# cmake --build . --target _C
# cmake --install . --component _C
project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")
# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
@ -70,19 +84,6 @@ endif()
find_package(Torch REQUIRED)
#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")
# Define _core_C extension
@ -100,8 +101,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
add_dependencies(default _core_C)
#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()
include(FetchContent)
#
# Define other extension targets
#
@ -190,7 +191,6 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
FetchContent_Declare(
cutlass
@ -283,6 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
csrc/quantization/machete/machete_pytorch.cu)
endif()
message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
DESTINATION vllm
@ -313,6 +314,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_moe_ops.cu")
endif()
message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
_moe_C
DESTINATION vllm
@ -323,7 +325,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)
if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
@ -343,16 +344,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()
# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
return()
endif ()
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed.
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()
if(VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling rocm extension.")
add_dependencies(default _rocm_C)
if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE
)
endif()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)
# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
# Copy over the vllm-flash-attn python files
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py"
)
# Nothing after vllm-flash-attn, see comment about macros above

View File

@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################

View File

@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()

View File

@ -8,4 +8,3 @@ torch == 2.4.0
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0

View File

@ -6,6 +6,7 @@ import re
import subprocess
import sys
import warnings
from pathlib import Path
from shutil import which
from typing import Dict, List
@ -152,15 +153,8 @@ class cmake_build_ext(build_ext):
default_cfg = "Debug" if self.debug else "RelWithDebInfo"
cfg = envs.CMAKE_BUILD_TYPE or default_cfg
# where .so files will be written, should be the same for all extensions
# that use the same CMakeLists.txt.
outdir = os.path.abspath(
os.path.dirname(self.get_ext_fullpath(ext.name)))
cmake_args = [
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
]
@ -224,10 +218,12 @@ class cmake_build_ext(build_ext):
os.makedirs(self.build_temp)
targets = []
target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
"vllm_flash_attn.")
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(remove_prefix(ext.name, "vllm."))
targets.append(target_name(ext.name))
num_jobs, _ = self.compute_num_jobs()
@ -240,6 +236,28 @@ class cmake_build_ext(build_ext):
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
# Install the libraries
for ext in self.extensions:
# Install the extension into the proper location
outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
# Skip if the install directory is the same as the build directory
if outdir == self.build_temp:
continue
# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
prefix = outdir
for i in range(ext.name.count('.')):
prefix = prefix.parent
# prefix here should actually be the same for all components
install_args = [
"cmake", "--install", ".", "--prefix", prefix, "--component",
target_name(ext.name)
]
subprocess.check_call(install_args, cwd=self.build_temp)
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@ -467,6 +485,10 @@ if _is_cuda() or _is_hip():
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))

View File

@ -19,8 +19,13 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
# yapf: disable
from vllm.vllm_flash_attn import (
flash_attn_varlen_func as _flash_attn_varlen_func)
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as _flash_attn_with_kvcache)
# yapf: enable
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])

View File

@ -244,8 +244,7 @@ def which_attn_to_use(
# FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN:
try:
import vllm_flash_attn # noqa: F401
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
@ -258,8 +257,9 @@ def which_attn_to_use(
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance.")
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
selected_backend = _Backend.XFORMERS
return selected_backend