[Kernel] Build flash-attn from source (#8245)

Author: Luka Govedič (committed by GitHub)
Date:   2024-09-21 02:27:10 -04:00
Commit: 71c60491f2 (parent 0faab90eb0)
9 changed files with 124 additions and 41 deletions


@@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt
 export MAX_JOBS=1
 # Make sure release wheels are built for the following architectures
 export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
+export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real"
 # Build
 $python_executable setup.py bdist_wheel --dist-dir=dist
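The new VLLM_FA_CMAKE_GPU_ARCHES export narrows the architectures the vendored flash-attention kernels are compiled for, independently of TORCH_CUDA_ARCH_LIST. Below is a minimal Python sketch of a release-build driver that sets both variables before producing the wheel; the function name and dist directory are illustrative, only the environment variables and the setup.py invocation come from this script.

    import os
    import subprocess
    import sys

    def build_release_wheel(dist_dir: str = "dist") -> None:
        """Sketch: build a release wheel with the same environment as the script above."""
        env = os.environ.copy()
        env["MAX_JOBS"] = "1"  # keep peak build memory bounded
        # Architectures for vLLM's own CUDA kernels.
        env["TORCH_CUDA_ARCH_LIST"] = "7.0 7.5 8.0 8.6 8.9 9.0+PTX"
        # Narrower list for the flash-attention build, to keep the wheel small.
        env["VLLM_FA_CMAKE_GPU_ARCHES"] = "80-real;90-real"
        subprocess.check_call(
            [sys.executable, "setup.py", "bdist_wheel", f"--dist-dir={dist_dir}"],
            env=env)

    if __name__ == "__main__":
        build_release_wheel()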

.gitignore

@@ -1,6 +1,9 @@
 # vllm commit id, generated by setup.py
 vllm/commit_id.py
 
+# vllm-flash-attn built from source
+vllm/vllm_flash_attn/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -12,6 +15,8 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+cmake-build-*/
+CMakeUserPresets.json
 develop-eggs/
 dist/
 downloads/


@@ -1,5 +1,16 @@
 cmake_minimum_required(VERSION 3.26)
 
+# When building directly using CMake, make sure you run the install step
+# (it places the .so files in the correct location).
+#
+# Example:
+# mkdir build && cd build
+# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
+# cmake --build . --target install
+#
+# If you want to only build one target, make sure to install it manually:
+# cmake --build . --target _C
+# cmake --install . --component _C
 project(vllm_extensions LANGUAGES CXX)
 
 # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
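The comment block added above documents the direct-CMake workflow, including the new requirement to run the install step so the .so files land in the right place. Here is a hedged Python sketch of the same sequence; it just shells out to the commands shown in the comment, and the helper function and its arguments are illustrative.

    import shutil
    import subprocess
    from pathlib import Path
    from typing import Optional

    def cmake_build_and_install(source_dir: Path, component: Optional[str] = None) -> None:
        """Configure, build, and install as described in the CMakeLists.txt comment."""
        build_dir = source_dir / "build"
        build_dir.mkdir(exist_ok=True)
        python = shutil.which("python3") or "python3"
        subprocess.check_call(
            ["cmake", "-G", "Ninja",
             f"-DVLLM_PYTHON_EXECUTABLE={python}",
             f"-DCMAKE_INSTALL_PREFIX={source_dir}",
             str(source_dir)],
            cwd=build_dir)
        if component is None:
            # Build everything and run the install step (places the .so files).
            subprocess.check_call(["cmake", "--build", ".", "--target", "install"],
                                  cwd=build_dir)
        else:
            # Build a single target (e.g. "_C") and install just that component.
            subprocess.check_call(["cmake", "--build", ".", "--target", component],
                                  cwd=build_dir)
            subprocess.check_call(["cmake", "--install", ".", "--component", component],
                                  cwd=build_dir)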
@@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 # Suppress potential warnings about unused manually-specified variables
 set(ignoreMe "${VLLM_PYTHON_PATH}")
 
+# Prevent installation of dependencies (cutlass) by default.
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
+
 #
 # Supported python versions. These versions will be searched in order, the
 # first match will be selected. These should be kept in sync with setup.py.
@@ -70,19 +84,6 @@ endif()
 
 find_package(Torch REQUIRED)
 
-#
-# Add the `default` target which detects which extensions should be
-# built based on platform/architecture. This is the same logic that
-# setup.py uses to select which extensions should be built and should
-# be kept in sync.
-#
-# The `default` target makes direct use of cmake easier since knowledge
-# of which extensions are supported has been factored in, e.g.
-#
-# mkdir build && cd build
-# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
-# cmake --build . --target default
-#
-add_custom_target(default)
-
 message(STATUS "Enabling core extension.")
 
 # Define _core_C extension
@@ -100,8 +101,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-add_dependencies(default _core_C)
-
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()
 
+include(FetchContent)
+
 #
 # Define other extension targets
 #
@@ -190,7 +191,6 @@ set(VLLM_EXT_SRC
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
-  include(FetchContent)
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
   FetchContent_Declare(
         cutlass
@@ -283,6 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       csrc/quantization/machete/machete_pytorch.cu)
 endif()
 
+message(STATUS "Enabling C extension.")
 define_gpu_extension_target(
   _C
   DESTINATION vllm
@@ -313,6 +314,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/moe/marlin_moe_ops.cu")
 endif()
 
+message(STATUS "Enabling moe extension.")
 define_gpu_extension_target(
   _moe_C
   DESTINATION vllm
@@ -323,7 +325,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-
 if(VLLM_GPU_LANG STREQUAL "HIP")
   #
   # _rocm_C extension
@@ -343,16 +344,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
     WITH_SOABI)
 endif()
 
-if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
-  message(STATUS "Enabling C extension.")
-  add_dependencies(default _C)
-
-  message(STATUS "Enabling moe extension.")
-  add_dependencies(default _moe_C)
-endif()
-
-if(VLLM_GPU_LANG STREQUAL "HIP")
-  message(STATUS "Enabling rocm extension.")
-  add_dependencies(default _rocm_C)
-endif()
+# vllm-flash-attn currently only supported on CUDA
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
+  return()
+endif ()
+
+#
+# Build vLLM flash attention from source
+#
+# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
+# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
+# They should be identical but if they aren't, this is a massive footgun.
+#
+# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
+# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
+# If no component is specified, vllm-flash-attn is still installed.
+
+# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
+# This is to enable local development of vllm-flash-attn within vLLM.
+# It can be set as an environment variable or passed as a cmake argument.
+# The environment variable takes precedence.
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
+  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
+endif()
+
+if(VLLM_FLASH_ATTN_SRC_DIR)
+  FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
+else()
+  FetchContent_Declare(
+        vllm-flash-attn
+        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
+        GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
+        GIT_PROGRESS TRUE
+  )
+endif()
+
+# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
+set(VLLM_PARENT_BUILD ON)
+
+# Make sure vllm-flash-attn install rules are nested under vllm/
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
+install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
+install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
+
+# Fetch the vllm-flash-attn library
+FetchContent_MakeAvailable(vllm-flash-attn)
+message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
+
+# Restore the install prefix
+install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
+install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
+
+# Copy over the vllm-flash-attn python files
+install(
+  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+  DESTINATION vllm/vllm_flash_attn
+  COMPONENT vllm_flash_attn_c
+  FILES_MATCHING PATTERN "*.py"
+)
+
+# Nothing after vllm-flash-attn, see comment about macros above
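The block above also supports local development: when VLLM_FLASH_ATTN_SRC_DIR is set (the environment variable takes precedence over the cache variable), FetchContent uses that checkout instead of cloning the pinned GIT_TAG. A hedged sketch of driving an editable vLLM install against a local flash-attention tree follows; the pip invocation and the path are illustrative, only the variable name comes from this commit.

    import os
    import subprocess
    import sys
    from pathlib import Path

    def install_with_local_flash_attn(flash_attn_checkout: Path) -> None:
        """Build vLLM against a local vllm-flash-attn checkout instead of the pinned GIT_TAG."""
        env = os.environ.copy()
        # Read by the CMake block above and forwarded to FetchContent_Declare(SOURCE_DIR ...).
        env["VLLM_FLASH_ATTN_SRC_DIR"] = str(flash_attn_checkout.resolve())
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."], env=env)

    if __name__ == "__main__":
        install_with_local_flash_attn(Path("../flash-attention"))  # illustrative path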


@@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 # see https://github.com/pytorch/pytorch/pull/123243
 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
 ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
+# Override the arch list for flash-attn to reduce the binary size
+ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
+ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
 #################### BASE BUILD IMAGE ####################
 
 #################### WHEEL BUILD IMAGE ####################
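Because the flash-attention arch list is exposed as a build argument, an image for a single GPU generation can be built with a narrower list than the default. A hedged example of overriding it at image build time; the image tag is a placeholder and the Dockerfile is assumed to be the repository default.

    import subprocess

    # Build an image whose flash-attention kernels target only SM90 (Hopper).
    # Only the build-arg name comes from the Dockerfile change; the rest is illustrative.
    subprocess.check_call([
        "docker", "build",
        "--build-arg", "vllm_fa_cmake_gpu_arches=90-real",
        "-t", "vllm:hopper-only",  # placeholder tag
        ".",
    ])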


@@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
   endif()
 
-  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
+  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
 endfunction()


@@ -8,4 +8,3 @@ torch == 2.4.0
 # These must be updated alongside torch
 torchvision == 0.19  # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64'  # Requires PyTorch 2.4.0


@@ -6,6 +6,7 @@ import re
 import subprocess
 import sys
 import warnings
+from pathlib import Path
 from shutil import which
 from typing import Dict, List
 
@@ -152,15 +153,8 @@ class cmake_build_ext(build_ext):
         default_cfg = "Debug" if self.debug else "RelWithDebInfo"
         cfg = envs.CMAKE_BUILD_TYPE or default_cfg
 
-        # where .so files will be written, should be the same for all extensions
-        # that use the same CMakeLists.txt.
-        outdir = os.path.abspath(
-            os.path.dirname(self.get_ext_fullpath(ext.name)))
-
         cmake_args = [
             '-DCMAKE_BUILD_TYPE={}'.format(cfg),
-            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
-            '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
             '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
         ]
@@ -224,10 +218,12 @@ class cmake_build_ext(build_ext):
             os.makedirs(self.build_temp)
 
         targets = []
+        target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
+                                              "vllm_flash_attn.")
         # Build all the extensions
         for ext in self.extensions:
             self.configure(ext)
-            targets.append(remove_prefix(ext.name, "vllm."))
+            targets.append(target_name(ext.name))
 
         num_jobs, _ = self.compute_num_jobs()
@@ -240,6 +236,28 @@ class cmake_build_ext(build_ext):
 
         subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
 
+        # Install the libraries
+        for ext in self.extensions:
+            # Install the extension into the proper location
+            outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
+
+            # Skip if the install directory is the same as the build directory
+            if outdir == self.build_temp:
+                continue
+
+            # CMake appends the extension prefix to the install path,
+            # and outdir already contains that prefix, so we need to remove it.
+            prefix = outdir
+            for i in range(ext.name.count('.')):
+                prefix = prefix.parent
+
+            # prefix here should actually be the same for all components
+            install_args = [
+                "cmake", "--install", ".", "--prefix", prefix, "--component",
+                target_name(ext.name)
+            ]
+            subprocess.check_call(install_args, cwd=self.build_temp)
+
 
 def _no_device() -> bool:
     return VLLM_TARGET_DEVICE == "empty"
@@ -467,6 +485,10 @@ if _is_cuda() or _is_hip():
     if _is_hip():
         ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
 
+    if _is_cuda():
+        ext_modules.append(
+            CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
+
     if _build_custom_ops():
         ext_modules.append(CMakeExtension(name="vllm._C"))
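The new install loop derives the CMake component name by stripping the "vllm." and "vllm_flash_attn." prefixes from the extension name, so "vllm.vllm_flash_attn.vllm_flash_attn_c" maps to the vllm_flash_attn_c component defined in CMakeLists.txt and "vllm._C" maps to _C. A standalone sketch of that mapping and the resulting install call; remove_prefix mirrors the helper already in setup.py, and the prefix and build directory arguments are placeholders.

    import subprocess

    def remove_prefix(text: str, prefix: str) -> str:
        # Same behaviour as the helper used in setup.py (str.removeprefix needs Python 3.9+).
        return text[len(prefix):] if text.startswith(prefix) else text

    def target_name(ext_name: str) -> str:
        return remove_prefix(remove_prefix(ext_name, "vllm."), "vllm_flash_attn.")

    assert target_name("vllm._C") == "_C"
    assert target_name("vllm.vllm_flash_attn.vllm_flash_attn_c") == "vllm_flash_attn_c"

    def install_component(build_temp: str, ext_name: str, prefix: str) -> None:
        """Install one extension's component, as cmake_build_ext.build_extensions now does."""
        subprocess.check_call(
            ["cmake", "--install", ".", "--prefix", prefix,
             "--component", target_name(ext_name)],
            cwd=build_temp)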


@@ -19,8 +19,13 @@ if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
-from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
-from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
+# yapf: disable
+from vllm.vllm_flash_attn import (
+    flash_attn_varlen_func as _flash_attn_varlen_func)
+from vllm.vllm_flash_attn import (
+    flash_attn_with_kvcache as _flash_attn_with_kvcache)
+# yapf: enable
 
 
 @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
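Downstream code that previously imported the standalone vllm_flash_attn wheel now needs the in-tree package path. A hedged compatibility sketch for code that wants to run against both layouts; the fallback to the old top-level package is an assumption, not part of this commit.

    try:
        # New layout: built from source and installed under vllm/vllm_flash_attn (this commit).
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    except ImportError:
        # Old layout: the separately published vllm-flash-attn wheel.
        from vllm_flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]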


@@ -244,8 +244,7 @@ def which_attn_to_use(
     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
         try:
-            import vllm_flash_attn  # noqa: F401
-
+            import vllm.vllm_flash_attn  # noqa: F401
             from vllm.attention.backends.flash_attn import (  # noqa: F401
                 FlashAttentionBackend)
@@ -258,8 +257,9 @@ def which_attn_to_use(
         except ImportError:
             logger.info(
                 "Cannot use FlashAttention-2 backend because the "
-                "vllm_flash_attn package is not found. "
-                "`pip install vllm-flash-attn` for better performance.")
+                "vllm.vllm_flash_attn package is not found. "
+                "Make sure that vllm_flash_attn was built and installed "
+                "(on by default).")
             selected_backend = _Backend.XFORMERS
 
     return selected_backend
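A quick way to confirm that the from-source build actually produced the package the selector probes for, before relying on the FLASH_ATTN backend; this helper is illustrative and not part of the commit.

    import importlib.util

    def vllm_flash_attn_available() -> bool:
        # Mirrors the import probe in which_attn_to_use: the package only exists once the
        # vllm_flash_attn_c component has been built and its Python files installed under vllm/.
        return importlib.util.find_spec("vllm.vllm_flash_attn") is not None

    if __name__ == "__main__":
        print("vllm.vllm_flash_attn importable:", vllm_flash_attn_available())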