[Kernel] Flash Attention 3 Support (#12093)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-01-23 09:45:48 -05:00 committed by GitHub
parent c5b4b11d7f
commit 978b45f399
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 151 additions and 83 deletions

View File

@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables # Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}") set(ignoreMe "${VLLM_PYTHON_PATH}")
# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# #
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif() endif()
# vllm-flash-attn currently only supported on CUDA # vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
return() return()
endif () endif ()
@ -558,7 +555,7 @@ endif()
# They should be identical but if they aren't, this is a massive footgun. # They should be identical but if they aren't, this is a massive footgun.
# #
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c. # To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed. # If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
@ -570,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif() endif()
if(VLLM_FLASH_ATTN_SRC_DIR) if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else() else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
) )
endif() endif()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
# Fetch the vllm-flash-attn library # Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn) FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix # Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) # case only one is built, in the case both are built redundant work is done)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
# Copy over the vllm-flash-attn python files
install( install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn DESTINATION vllm_flash_attn
COMPONENT vllm_flash_attn_c COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py" FILES_MATCHING PATTERN "*.py"
)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
) )
# Nothing after vllm-flash-attn, see comment about macros above # Nothing after vllm-flash-attn, see comment about macros above

View File

@ -228,8 +228,11 @@ class cmake_build_ext(build_ext):
# CMake appends the extension prefix to the install path, # CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it. # and outdir already contains that prefix, so we need to remove it.
# We assume only the final component of extension prefix is added by
# CMake, this is currently true for current extensions but may not
# always be the case.
prefix = outdir prefix = outdir
for i in range(ext.name.count('.')): if '.' in ext.name:
prefix = prefix.parent prefix = prefix.parent
# prefix here should actually be the same for all components # prefix here should actually be the same for all components
@ -298,7 +301,8 @@ class repackage_wheel(build_ext):
files_to_copy = [ files_to_copy = [
"vllm/_C.abi3.so", "vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so", "vllm/_moe_C.abi3.so",
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py", "vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py", "vllm/vllm_flash_attn/__init__.py",
"vllm/cumem_allocator.abi3.so", "vllm/cumem_allocator.abi3.so",
@ -593,8 +597,8 @@ if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda(): if _is_cuda():
ext_modules.append( ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops(): if _build_custom_ops():

View File

@ -78,6 +78,7 @@ CASES = [
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50]) @pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048]) @pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_cascade( def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
@ -87,8 +88,14 @@ def test_cascade(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
window_size = (-1, -1) window_size = (-1, -1)
@ -118,9 +125,7 @@ def test_cascade(
cu_query_lens = torch.tensor([0] + query_lens, cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0, dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32) dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens, kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
num_blocks, num_blocks,
@ -140,7 +145,7 @@ def test_cascade(
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens, seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len, max_seqlen_k=max_kv_len,
softmax_scale=scale, softmax_scale=scale,
@ -154,10 +159,8 @@ def test_cascade(
assert all(common_prefix_len < kv_len for kv_len in kv_lens) assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32) dtype=torch.int32)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
cu_suffix_kv_lens = ( suffix_kv_lens = kv_lens_tensor - common_prefix_len
cu_kv_lens -
torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
output = torch.empty_like(query) output = torch.empty_like(query)
cascade_attention( cascade_attention(
output=output, output=output,
@ -167,8 +170,8 @@ def test_cascade(
cu_query_lens=cu_query_lens, cu_query_lens=cu_query_lens,
max_query_len=max_query_len, max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens, cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens, prefix_kv_lens=prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len, max_kv_len=max_kv_len,
softmax_scale=scale, softmax_scale=scale,
alibi_slopes=None, alibi_slopes=None,
@ -176,6 +179,7 @@ def test_cascade(
logits_soft_cap=soft_cap if soft_cap is not None else 0, logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables, block_table=block_tables,
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
fa_version=fa_version,
) )
# Compare the results. # Compare the results.

View File

@ -80,6 +80,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
use_out: bool, use_out: bool,
@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
sliding_window: Optional[int], sliding_window: Optional[int],
fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
num_query_heads = num_heads[0] num_query_heads = num_heads[0]
@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size, window_size=window_size,
fa_version=fa_version,
) )
output = output if not use_out else out output = output if not use_out else out
output = output.squeeze(1) output = output.squeeze(1)
@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
use_out: bool, use_out: bool,
@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
or torch.cuda.get_device_capability() == (8, 9)):
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
@ -198,9 +213,7 @@ def test_varlen_with_paged_kv(
cu_query_lens = torch.tensor([0] + query_lens, cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0, dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32) dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens, kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
@ -215,7 +228,7 @@ def test_varlen_with_paged_kv(
v=value_cache, v=value_cache,
out=out, out=out,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens, seqused_k=kv_lens,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len, max_seqlen_k=max_kv_len,
softmax_scale=scale, softmax_scale=scale,
@ -223,6 +236,7 @@ def test_varlen_with_paged_kv(
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
) )
output = output if not use_out else out output = output if not use_out else out

View File

@ -17,7 +17,9 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
@ -25,7 +27,8 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache) flash_attn_with_kvcache,
is_fa_version_supported)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@ -634,6 +637,20 @@ class FlashAttentionImpl(AttentionImpl):
f"Supported head sizes are: {support_head_sizes}.") f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type self.attn_type = attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
assert is_fa_version_supported(self.fa_version)
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
@ -752,6 +769,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.fa_version,
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
@ -765,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
v=value_cache, v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc, cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len, max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc, seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len, max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=True,
@ -774,6 +792,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.fa_version,
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
@ -793,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
v=value_cache, v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc, cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len, max_seqlen_q=decode_meta.max_decode_query_len,
cu_seqlens_k=decode_meta.seq_start_loc, seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len, max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=True,
@ -802,6 +821,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap, softcap=logits_soft_cap,
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
out=decode_output, out=decode_output,
fa_version=self.fa_version,
) )
else: else:
# Use flash_attn_with_kvcache for normal decoding. # Use flash_attn_with_kvcache for normal decoding.
@ -822,6 +842,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=decode_output.unsqueeze(1), out=decode_output.unsqueeze(1),
fa_version=self.fa_version,
) )
return output return output

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
@ -90,6 +91,12 @@ def get_default_config_root():
) )
def maybe_convert_int(value: Optional[str]) -> Optional[int]:
if value is None:
return None
return int(value)
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the used env vars. # to extract the used env vars.
@ -203,6 +210,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION":
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
# Internal flag to enable Dynamo fullgraph capture # Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool( lambda: bool(

View File

@ -9,8 +9,11 @@ import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import (flash_attn_varlen_func,
is_fa_version_supported)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@ -63,7 +66,7 @@ class FlashAttentionMetadata:
max_query_len: int max_query_len: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
max_seq_len: int max_seq_len: int
seq_start_loc: torch.Tensor seq_lens: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
@ -71,8 +74,8 @@ class FlashAttentionMetadata:
use_cascade: bool use_cascade: bool
common_prefix_len: int common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor] cu_prefix_query_lens: Optional[torch.Tensor]
cu_prefix_kv_lens: Optional[torch.Tensor] prefix_kv_lens: Optional[torch.Tensor]
cu_suffix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor]
# For logging. # For logging.
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
@ -128,6 +131,20 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashAttentionImpl") "FlashAttentionImpl")
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9:
self.fa_version = 3 if is_fa_version_supported(3) else 2
else:
self.fa_version = 2
if VLLM_FLASH_ATTN_VERSION is not None:
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION
assert is_fa_version_supported(self.fa_version)
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -196,7 +213,7 @@ class FlashAttentionImpl(AttentionImpl):
out=output[:num_actual_tokens], out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc, cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len, max_seqlen_q=attn_metadata.max_query_len,
cu_seqlens_k=attn_metadata.seq_start_loc, seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len, max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
@ -204,6 +221,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window, window_size=self.sliding_window,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
fa_version=self.fa_version,
) )
return output return output
@ -216,8 +234,8 @@ class FlashAttentionImpl(AttentionImpl):
cu_query_lens=attn_metadata.query_start_loc, cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len, max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
cu_prefix_kv_lens=attn_metadata.cu_prefix_kv_lens, prefix_kv_lens=attn_metadata.prefix_kv_lens,
cu_suffix_kv_lens=attn_metadata.cu_suffix_kv_lens, suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len, max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
@ -225,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap, logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len, common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.fa_version,
) )
return output return output
@ -305,8 +324,8 @@ def cascade_attention(
cu_query_lens: torch.Tensor, cu_query_lens: torch.Tensor,
max_query_len: int, max_query_len: int,
cu_prefix_query_lens: torch.Tensor, cu_prefix_query_lens: torch.Tensor,
cu_prefix_kv_lens: torch.Tensor, prefix_kv_lens: torch.Tensor,
cu_suffix_kv_lens: torch.Tensor, suffix_kv_lens: torch.Tensor,
max_kv_len: int, max_kv_len: int,
softmax_scale: float, softmax_scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
@ -314,6 +333,7 @@ def cascade_attention(
logits_soft_cap: float, logits_soft_cap: float,
block_table: torch.Tensor, block_table: torch.Tensor,
common_prefix_len: int, common_prefix_len: int,
fa_version: int,
) -> torch.Tensor: ) -> torch.Tensor:
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
# TODO: Support sliding window. # TODO: Support sliding window.
@ -332,7 +352,7 @@ def cascade_attention(
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
cu_seqlens_q=cu_prefix_query_lens, cu_seqlens_q=cu_prefix_query_lens,
cu_seqlens_k=cu_prefix_kv_lens, seqused_k=prefix_kv_lens,
max_seqlen_q=num_tokens, max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len, max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
@ -341,6 +361,7 @@ def cascade_attention(
block_table=block_table[:1], block_table=block_table[:1],
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
fa_version=fa_version,
) )
# Process suffix per query. # Process suffix per query.
@ -349,7 +370,7 @@ def cascade_attention(
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_suffix_kv_lens, seqused_k=suffix_kv_lens,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len, max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
@ -358,6 +379,7 @@ def cascade_attention(
block_table=block_table[:, num_common_kv_blocks:], block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
fa_version=fa_version,
) )
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.

View File

@ -199,11 +199,11 @@ class GPUModelRunner:
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy() self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() self.seq_lens_np = self.seq_lens_cpu.numpy()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states. # Remove stopped requests from the cached states.
@ -412,11 +412,10 @@ class GPUModelRunner:
np.cumsum(num_scheduled_tokens, np.cumsum(num_scheduled_tokens,
out=self.query_start_loc_np[1:num_reqs + 1]) out=self.query_start_loc_np[1:num_reqs + 1])
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + self.seq_lens_np[:num_reqs] = (
num_scheduled_tokens) self.input_batch.num_computed_tokens_cpu[:num_reqs] +
max_seq_len = seq_lens.max() num_scheduled_tokens)
self.seq_start_loc_np[0] = 0 max_seq_len = self.seq_lens_np[:num_reqs].max()
np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids[:total_num_scheduled_tokens].copy_(
@ -433,8 +432,8 @@ class GPUModelRunner:
non_blocking=True) non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True) self.device, non_blocking=True)
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to( seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
self.device, non_blocking=True) non_blocking=True)
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to( slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True).long() self.device, non_blocking=True).long()
@ -506,33 +505,30 @@ class GPUModelRunner:
[0, total_num_scheduled_tokens], [0, total_num_scheduled_tokens],
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
cu_suffix_kv_lens = ( suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
self.seq_start_loc_np[:num_reqs + 1] - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
self.arange_np[:num_reqs + 1] * common_prefix_len)
cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
self.device)
else: else:
cu_prefix_query_lens = None cu_prefix_query_lens = None
cu_prefix_kv_lens = None prefix_kv_lens = None
cu_suffix_kv_lens = None suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc, seq_lens=seq_lens,
block_table=( block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]), self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
use_cascade=use_cascade, use_cascade=use_cascade,
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens, cu_prefix_query_lens=cu_prefix_query_lens,
cu_prefix_kv_lens=cu_prefix_kv_lens, prefix_kv_lens=prefix_kv_lens,
cu_suffix_kv_lens=cu_suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
) )
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this # request in the batch. While we should not sample any token from this