[Bugfix][Kernel] Fix CUDA 11.8 being broken by FA3 build (#12375)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Parent: 3bb8e2c9a2
Commit: ab5bbf5ae3
CMakeLists.txt (2 changes, Normal file → Executable file)

@@ -576,7 +576,7 @@ else()
     FetchContent_Declare(
             vllm-flash-attn
             GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-            GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
+            GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
             GIT_PROGRESS TRUE
             # Don't share the vllm-flash-attn build between build types
             BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
setup.py (5 changes, Normal file → Executable file)

@@ -598,7 +598,10 @@ if _is_hip():

 if _is_cuda():
     ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
-    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
+        # FA3 requires CUDA 12.0 or later
+        ext_modules.append(
+            CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

 if _build_custom_ops():
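For readers checking the build gate above, here is a minimal standalone sketch of the same logic. The names envs.VLLM_USE_PRECOMPILED and get_nvcc_cuda_version() are the ones used in setup.py; the helper below and its example inputs are illustrative stand-ins, not vLLM API.

# Minimal sketch of the CUDA gate in the setup.py hunk above (assumptions noted).
from packaging.version import Version

def should_build_fa3(use_precompiled: bool, nvcc_version: str) -> bool:
    # FA3 kernels need nvcc >= 12.0; the FA2 extension is built regardless.
    return use_precompiled or Version(nvcc_version) >= Version("12.0")

assert should_build_fa3(False, "12.1")      # CUDA 12.x source build: FA3 included
assert not should_build_fa3(False, "11.8")  # CUDA 11.8 source build: FA3 skipped
assert should_build_fa3(True, "11.8")       # precompiled wheels bypass the nvcc check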
tests/kernels/test_cascade_flash_attn.py (11 changes, Normal file → Executable file)

@@ -6,7 +6,9 @@ import torch
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (cascade_attention,
                                                     merge_attn_states)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  is_fa_version_supported)

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
@@ -91,10 +93,9 @@ def test_cascade(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

     current_platform.seed_everything(0)

tests/kernels/test_flash_attn.py

@@ -4,8 +4,10 @@ import pytest
 import torch

 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
-
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
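The fa_version argument in the tests above comes from pytest parametrization; the decorator itself sits outside these hunks, so the sketch below is an assumed wiring (the test name and parameter list are hypothetical), not a copy of the test files.

# Hypothetical parametrization sketch showing how a test body would receive
# fa_version 2 and 3 and skip at runtime when that version is unsupported.
import pytest

from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)


@pytest.mark.parametrize("fa_version", [2, 3])
def test_fa_version_dispatch(fa_version: int) -> None:
    if not is_fa_version_supported(fa_version):
        pytest.skip(f"FA{fa_version} unsupported: "
                    f"{fa_version_unsupported_reason(fa_version)}")
    # ... exercise flash_attn_varlen_func / flash_attn_with_kvcache here ...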
vllm/attention/backends/flash_attn.py (14 changes, Normal file → Executable file)

@@ -18,17 +18,20 @@ from vllm.attention.backends.utils import (
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)

-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+logger = init_logger(__name__)


 class FlashAttentionBackend(AttentionBackend):
@@ -652,6 +655,11 @@ class FlashAttentionImpl(AttentionImpl):
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION

+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+
         assert is_fa_version_supported(self.fa_version)

     def forward(
vllm/v1/attention/backends/flash_attn.py (11 changes, Normal file → Executable file)

@@ -10,11 +10,15 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
                                   is_fa_version_supported)

+logger = init_logger(__name__)
+

 class FlashAttentionBackend(AttentionBackend):

@@ -143,6 +147,11 @@ class FlashAttentionImpl(AttentionImpl):
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION

+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+
         assert is_fa_version_supported(self.fa_version)

     def forward(
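To see what these runtime checks report on a given machine, the snippet below queries the two helpers the backends now call at init time; it requires a CUDA build of vLLM, and the printed format is illustrative rather than vLLM's own output.

# Quick diagnostic using the helpers referenced in the diff above.
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

for fa_version in (2, 3):
    if is_fa_version_supported(fa_version):
        print(f"FA{fa_version}: supported")
    else:
        print(f"FA{fa_version}: unsupported "
              f"({fa_version_unsupported_reason(fa_version)})")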