[Kernel] FlashMLA integration (#13747)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
b382a7f28f
commit
f95903909f
@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
# vllm-flash-attn currently only supported on CUDA
|
||||
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
return()
|
||||
# For CUDA we also build and ship some external projects.
|
||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(cmake/external_projects/flashmla.cmake)
|
||||
include(cmake/external_projects/vllm_flash_attn.cmake)
|
||||
endif ()
|
||||
|
||||
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
|
||||
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
|
||||
# arches in the CUDA case (and instead set the gencodes on a per file basis)
|
||||
# we need to manually set VLLM_GPU_ARCHES here.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build vLLM flash attention from source
|
||||
#
|
||||
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
|
||||
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
# This is to enable local development of vllm-flash-attn within vLLM.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
# Nothing after vllm-flash-attn, see comment about macros above
|
||||
|
66
cmake/external_projects/flashmla.cmake
Normal file
66
cmake/external_projects/flashmla.cmake
Normal file
@ -0,0 +1,66 @@
|
||||
include(FetchContent)
|
||||
|
||||
# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
|
||||
# instead of downloading.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{FLASH_MLA_SRC_DIR})
|
||||
set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(FLASH_MLA_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
SOURCE_DIR ${FLASH_MLA_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
FetchContent_MakeAvailable(flashmla)
|
||||
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/include)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
else()
|
||||
# Create an empty target for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
endif()
|
||||
|
67
cmake/external_projects/vllm_flash_attn.cmake
Normal file
67
cmake/external_projects/vllm_flash_attn.cmake
Normal file
@ -0,0 +1,67 @@
|
||||
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
|
||||
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
|
||||
# arches in the CUDA case (and instead set the gencodes on a per file basis)
|
||||
# we need to manually set VLLM_GPU_ARCHES here.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build vLLM flash attention from source
|
||||
#
|
||||
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
|
||||
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
# This is to enable local development of vllm-flash-attn within vLLM.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
6
setup.py
6
setup.py
@ -328,6 +328,7 @@ class repackage_wheel(build_ext):
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.py",
|
||||
@ -612,6 +613,11 @@ if _is_cuda():
|
||||
# FA3 requires CUDA 12.0 or later
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
|
||||
# Optional since this doesn't get built (produce an .so file) when
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
|
132
tests/kernels/test_flashmla.py
Normal file
132
tests/kernels/test_flashmla.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||
(x * x + y * y).sum().item(), 1e-12)
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
||||
if not is_flashmla_supported()[0] else "FlashMLA is supported"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_flashmla_supported()[0],
|
||||
reason=FLASH_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1, 2])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192])
|
||||
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("h_kv", [1])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
varlen):
|
||||
# TODO: parametrize using pytest
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda:0")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}")
|
||||
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
|
||||
s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(b * max_seqlen_pad // block_size,
|
||||
dtype=torch.int32).view(
|
||||
b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
for i in range(b):
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv,
|
||||
d)[i, cache_seqlens[i].item():] = float("nan")
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv)
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k,
|
||||
dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
def ref_mla():
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
ref_O, LSE = scaled_dot_product_attention(
|
||||
q[i].transpose(0, 1),
|
||||
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = ref_O.transpose(0, 1)
|
||||
lse[i] = LSE
|
||||
return out, lse
|
||||
|
||||
out_flash, lse_flash = flash_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
cal_diff(out_flash, out_torch, "out")
|
||||
cal_diff(lse_flash, lse_torch, "lse")
|
||||
|
||||
t = triton.testing.do_bench(flash_mla, fast_flush=False)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
|
||||
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
|
||||
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
|
||||
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
|
@ -1163,3 +1163,67 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
||||
def register_graph_buffers(fa: int, handles: List[List[int]],
|
||||
offsets: List[List[int]]) -> None:
|
||||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def get_flash_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._C.get_flash_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
None,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
239
vllm/attention/backends/flashmla.py
Normal file
239
vllm/attention/backends/flashmla.py
Normal file
@ -0,0 +1,239 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
MLACommonState)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["FlashMLAState"]:
|
||||
return FlashMLAState
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata):
|
||||
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
|
||||
torch.Tensor]] = None
|
||||
decode_num_splits: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def decode_metadata(self):
|
||||
decode_metadata = super().decode_metadata
|
||||
# TODO: cache assignment?
|
||||
if decode_metadata is not None:
|
||||
decode_metadata.decode_tile_scheduler_metadata=\
|
||||
self.decode_tile_scheduler_metadata
|
||||
decode_metadata.decode_num_splits=\
|
||||
self.decode_num_splits
|
||||
return decode_metadata
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
raise NotImplementedError(
|
||||
"advance_step is not implemented for FlashMLA")
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
m = super().build(seq_lens, query_lens, cuda_graph_pad_size,
|
||||
batch_size)
|
||||
|
||||
if m.num_decode_tokens > 0:
|
||||
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
|
||||
get_mla_metadata(
|
||||
m.seq_lens_tensor[m.num_prefills:],
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class FlashMLAState(MLACommonState[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, *args, **kwds):
|
||||
super().__init__(*args, **kwds)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
# Run a dummy `get_mla_metadata` so we can get the right shapes
|
||||
self._graph_decoder_tile_scheduler_metadata, \
|
||||
self._graph_decode_num_splits = get_mla_metadata(
|
||||
torch.ones(
|
||||
max_batch_size, dtype=torch.int32, device=self.runner.device),
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
with super().graph_capture(max_batch_size):
|
||||
yield
|
||||
|
||||
del self._graph_decoder_tile_scheduler_metadata
|
||||
del self._graph_decode_num_splits
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
metadata = super().graph_capture_get_metadata_for_batch(
|
||||
batch_size, is_encoder_decoder_model)
|
||||
assert metadata.num_decode_tokens > 0
|
||||
|
||||
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
|
||||
self._graph_seq_lens[:batch_size],
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
self._graph_decoder_tile_scheduler_metadata.copy_(
|
||||
decoder_tile_scheduler_metadata)
|
||||
self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)
|
||||
|
||||
metadata.decode_tile_scheduler_metadata=\
|
||||
self._graph_decoder_tile_scheduler_metadata
|
||||
metadata.decode_num_splits=\
|
||||
self._graph_decode_num_splits[:batch_size + 1]
|
||||
|
||||
return metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = super().get_graph_input_buffers(
|
||||
attn_metadata, is_encoder_decoder_model)
|
||||
input_buffers["decode_tile_scheduler_metadata"] = \
|
||||
attn_metadata.decode_metadata.decode_tile_scheduler_metadata
|
||||
input_buffers["decode_num_splits"] = \
|
||||
attn_metadata.decode_metadata.decode_num_splits
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
|
||||
is_encoder_decoder_model)
|
||||
|
||||
input_buffers["decode_tile_scheduler_metadata"].copy_(
|
||||
attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
|
||||
input_buffers["decode_num_splits"].copy_(
|
||||
attn_metadata.decode_metadata.decode_num_splits)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 FlashMLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=decode_meta.block_tables,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
|
||||
num_splits=decode_meta.decode_num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
@ -293,7 +293,10 @@ class MLACommonBackend(AttentionBackend):
|
||||
return [576]
|
||||
|
||||
|
||||
class MLACommonState(AttentionState):
|
||||
T = TypeVar("T", bound="MLACommonMetadata")
|
||||
|
||||
|
||||
class MLACommonState(AttentionState, Generic[T]):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
@ -355,7 +358,9 @@ class MLACommonState(AttentionState):
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
self,
|
||||
batch_size: int,
|
||||
is_encoder_decoder_model: bool = False) -> T:
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
@ -507,8 +512,8 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["MLACommonMetadata"] = None
|
||||
_cached_decode_metadata: Optional["MLACommonMetadata"] = None
|
||||
_cached_prefill_metadata: Optional[Any] = None
|
||||
_cached_decode_metadata: Optional[Any] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
@ -537,7 +542,7 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
f" received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["MLACommonMetadata"]:
|
||||
def prefill_metadata(self):
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
@ -565,7 +570,7 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = MLACommonMetadata(
|
||||
self._cached_prefill_metadata = self.__class__(
|
||||
# Required by ModelRunner
|
||||
use_cuda_graph=False, # Not Attention Related
|
||||
# Required by Attention Metadata
|
||||
@ -599,7 +604,7 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["MLACommonMetadata"]:
|
||||
def decode_metadata(self):
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
@ -617,7 +622,7 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = MLACommonMetadata(
|
||||
self._cached_decode_metadata = self.__class__(
|
||||
# Required by ModelRunner
|
||||
use_cuda_graph=self.use_cuda_graph, # Not Attention Related
|
||||
# Required by Attention Metadata
|
||||
@ -723,10 +728,7 @@ class MLACommonMetadata(AttentionMetadata):
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]):
|
||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@ -959,7 +961,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]):
|
||||
assert max(context_chunk_seq_tot) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
|
||||
return MLACommonMetadata(
|
||||
return self.runner.attn_backend.make_metadata(
|
||||
# Required by ModelRunner
|
||||
use_cuda_graph=use_captured_graph, # Not Attention Related
|
||||
# Required by Attention Metadata
|
||||
|
115
vllm/attention/ops/flashmla.py
Normal file
115
vllm/attention/ops/flashmla.py
Normal file
@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_C # noqa: F401
|
||||
_flashmla_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Return: is_supported_flag, unsupported_reason (optional).
|
||||
"""
|
||||
if not current_platform.is_cuda():
|
||||
return False, "FlashMLA is only supported on CUDA devices."
|
||||
if current_platform.get_device_capability()[0] != 9:
|
||||
return False, "FlashMLA is only supported on Hopper devices."
|
||||
if not _flashmla_C_AVAILABLE:
|
||||
return False, "vllm._flashmla_C is not available, likely was not "\
|
||||
"compiled due to insufficient nvcc version or a supported arch "\
|
||||
"(only sm90a currently) was not in the list of target arches to "\
|
||||
"compile for."
|
||||
return True, None
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
None,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
#
|
||||
# TODO: Add fake functions
|
||||
#
|
||||
# @register_fake("_flashmla_C::get_mla_metadata")
|
||||
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
||||
# @register_fake("_flashmla_C::fwd_kvcache_mla")
|
||||
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# return ....
|
||||
#
|
@ -141,6 +141,14 @@ class CudaPlatformBase(Platform):
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
# TODO(lucas): handle this more gracefully
|
||||
if envs.VLLM_ATTENTION_BACKEND is not None \
|
||||
and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
|
||||
and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"FlashMLA: Forcing kv cache block size to 64 since this"
|
||||
" is currently the only block size supported by the kernel.")
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
@ -157,6 +165,22 @@ class CudaPlatformBase(Platform):
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
if use_mla:
|
||||
if selected_backend == _Backend.FLASHMLA:
|
||||
from vllm.attention.backends.flashmla import (
|
||||
is_flashmla_supported)
|
||||
if not is_flashmla_supported()[0]:
|
||||
logger.warning(
|
||||
"FlashMLA backend is not supported due to %s",
|
||||
is_flashmla_supported()[1])
|
||||
elif block_size != 64:
|
||||
logger.warning(
|
||||
"FlashMLA backend is not supported for block size %d"
|
||||
" (currently only supports block size 64).",
|
||||
block_size)
|
||||
else:
|
||||
logger.info("Using FlashMLA backend.")
|
||||
return "vllm.attention.backends.flashmla.FlashMLABackend"
|
||||
|
||||
logger.info("Using Triton MLA backend.")
|
||||
return "vllm.attention.backends.triton_mla.TritonMLABackend"
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
|
@ -35,6 +35,7 @@ class _Backend(enum.Enum):
|
||||
OPENVINO = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
TRITON_MLA = enum.auto()
|
||||
FLASHMLA = enum.auto()
|
||||
HPU_ATTN = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
|
Loading…
x
Reference in New Issue
Block a user