[Kernel] FlashMLA integration (#13747)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent b382a7f28f
commit f95903909f
@@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
     WITH_SOABI)
 endif()

-# vllm-flash-attn currently only supported on CUDA
-if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
-  return()
-endif ()
-
-# [... the vllm-flash-attn FetchContent/build/install logic that used to be
-#      inlined here in the top-level CMake file is removed; it reappears
-#      verbatim as the new file cmake/external_projects/vllm_flash_attn.cmake
-#      below ...]
+# For CUDA we also build and ship some external projects.
+if (VLLM_GPU_LANG STREQUAL "CUDA")
+  include(cmake/external_projects/flashmla.cmake)
+  include(cmake/external_projects/vllm_flash_attn.cmake)
+endif ()
cmake/external_projects/flashmla.cmake (new file, 66 lines)
include(FetchContent)

# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
# instead of downloading.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{FLASH_MLA_SRC_DIR})
  set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
endif()

if(FLASH_MLA_SRC_DIR)
  FetchContent_Declare(
    flashmla
    SOURCE_DIR ${FLASH_MLA_SRC_DIR}
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
  )
else()
  FetchContent_Declare(
    flashmla
    GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
    GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
    GIT_PROGRESS TRUE
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
  )
endif()

FetchContent_MakeAvailable(flashmla)
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")

# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
  set(FlashMLA_SOURCES
    ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
    ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)

  set(FlashMLA_INCLUDES
    ${flashmla_SOURCE_DIR}/csrc/cutlass/include
    ${flashmla_SOURCE_DIR}/csrc/include)

  set_gencode_flags_for_srcs(
    SRCS "${FlashMLA_SOURCES}"
    CUDA_ARCHS "${FLASH_MLA_ARCHS}")

  define_gpu_extension_target(
    _flashmla_C
    DESTINATION vllm
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES ${FlashMLA_SOURCES}
    COMPILE_FLAGS ${VLLM_GPU_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
    USE_SABI 3
    WITH_SOABI)
else()
  # Create an empty target for setup.py when not targeting sm90a systems
  add_custom_target(_flashmla_C)
endif()
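For orientation, a small runtime sketch (not part of the commit, and expect_flashmla_extension is an illustrative name) of the same gate in Python: the _flashmla_C extension is only compiled for sm90a with CUDA 12.3 or later, so at runtime it should only be importable on Hopper-class GPUs under a recent CUDA build.

import torch

def expect_flashmla_extension() -> bool:
    # False on CPU-only or ROCm builds of PyTorch.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    sm_major, _ = torch.cuda.get_device_capability()
    cuda_major, cuda_minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    # Mirrors the cmake gate: Hopper (sm90) plus CUDA 12.3 or later.
    return sm_major == 9 and (cuda_major, cuda_minor) >= (12, 3)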
cmake/external_projects/vllm_flash_attn.cmake (new file, 67 lines)
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
# arches in the CUDA case (and instead set the gencodes on a per file basis)
# we need to manually set VLLM_GPU_ARCHES here.
if(VLLM_GPU_LANG STREQUAL "CUDA")
  foreach(_ARCH ${CUDA_ARCHS})
    string(REPLACE "." "" _ARCH "${_ARCH}")
    list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
  endforeach()
endif()

#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLM's.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
  set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
  FetchContent_Declare(
    vllm-flash-attn SOURCE_DIR
    ${VLLM_FLASH_ATTN_SRC_DIR}
    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
  )
else()
  FetchContent_Declare(
    vllm-flash-attn
    GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
    GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
    GIT_PROGRESS TRUE
    # Don't share the vllm-flash-attn build between build types
    BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
  )
endif()

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
  DESTINATION vllm_flash_attn
  COMPONENT _vllm_fa2_C
  FILES_MATCHING PATTERN "*.py"
)

install(
  DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
  DESTINATION vllm_flash_attn
  COMPONENT _vllm_fa3_C
  FILES_MATCHING PATTERN "*.py"
)
setup.py (6 lines changed)

@@ -328,6 +328,7 @@ class repackage_wheel(build_ext):
         files_to_copy = [
             "vllm/_C.abi3.so",
             "vllm/_moe_C.abi3.so",
+            "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
             "vllm/vllm_flash_attn/flash_attn_interface.py",
@@ -612,6 +613,11 @@ if _is_cuda():
         # FA3 requires CUDA 12.0 or later
         ext_modules.append(
             CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
+        # Optional since this doesn't get built (produce an .so file) when
+        # not targeting a hopper system
+        ext_modules.append(
+            CMakeExtension(name="vllm._flashmla_C", optional=True))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

 if _build_custom_ops():
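A standalone sketch of the new setup.py predicate (should_register_flashmla and nvcc_version are illustrative names, not vLLM APIs): the optional vllm._flashmla_C extension is registered only for precompiled wheels or when nvcc is at least 12.3.

from packaging.version import Version

def should_register_flashmla(use_precompiled: bool, nvcc_version: Version) -> bool:
    # Mirrors: envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3")
    return use_precompiled or nvcc_version >= Version("12.3")

assert should_register_flashmla(True, Version("11.8"))
assert should_register_flashmla(False, Version("12.3"))
assert not should_register_flashmla(False, Version("12.1"))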
tests/kernels/test_flashmla.py (new file, 132 lines)
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
# SPDX-License-Identifier: Apache-2.0
import math
import random

import pytest
import torch
import triton

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    x, y = x.double(), y.double()
    cos_diff = 1 - 2 * (x * y).sum().item() / max(
        (x * x + y * y).sum().item(), 1e-12)
    assert cos_diff < 1e-5


FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
    if not is_flashmla_supported()[0] else "FlashMLA is supported"


@pytest.mark.skipif(not is_flashmla_supported()[0],
                    reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
                   varlen):
    # TODO: parametrize using pytest
    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    torch.set_default_dtype(dtype)
    torch.set_default_device(device)
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    random.seed(0)

    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
          f"{d=}, {dv=}, {causal=}, {varlen=}")

    cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
                                   s_q)
    total_seqlens = cache_seqlens.sum().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

    q = torch.randn(b, s_q, h_q, d)
    block_table = torch.arange(b * max_seqlen_pad // block_size,
                               dtype=torch.int32).view(
                                   b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv,
                       d)[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]

    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens, s_q * h_q // h_kv, h_kv)

    def flash_mla():
        return flash_mla_with_kvcache(
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            dv,
            tile_scheduler_metadata,
            num_splits,
            causal=causal,
        )

    def scaled_dot_product_attention(query, key, value, is_causal=False):
        query = query.float()
        key = key.float()
        value = value.float()
        key = key.repeat_interleave(h_q // h_kv, dim=0)
        value = value.repeat_interleave(h_q // h_kv, dim=0)
        attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
        if is_causal:
            s_q = query.shape[-2]
            s_k = key.shape[-2]
            attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
            temp_mask = torch.ones(s_q, s_k,
                                   dtype=torch.bool).tril(diagonal=s_k - s_q)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(query.dtype)
            attn_weight += attn_bias
        lse = attn_weight.logsumexp(dim=-1)
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
        return attn_weight @ value, lse

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            ref_O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                is_causal=causal,
            )
            out[i] = ref_O.transpose(0, 1)
            lse[i] = LSE
        return out, lse

    out_flash, lse_flash = flash_mla()
    out_torch, lse_torch = ref_mla()
    cal_diff(out_flash, out_torch, "out")
    cal_diff(lse_flash, lse_torch, "lse")

    t = triton.testing.do_bench(flash_mla, fast_flush=False)
    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
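The test compares the kernel against a float64 reference using a cosine-style difference rather than an element-wise tolerance; a tiny standalone illustration of that check:

import torch

x = torch.randn(64, dtype=torch.float64)
y = x + 1e-4 * torch.randn_like(x)  # small perturbation
# 1 - 2<x, y> / (|x|^2 + |y|^2) equals |x - y|^2 / (|x|^2 + |y|^2)
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5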
@@ -1163,3 +1163,67 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
 def register_graph_buffers(fa: int, handles: List[List[int]],
                            offsets: List[List[int]]) -> None:
     torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+
+def get_flash_mla_metadata(
+    cache_seqlens: torch.Tensor,
+    num_heads_per_head_k: int,
+    num_heads_k: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        cache_seqlens: (batch_size), dtype torch.int32.
+        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
+        num_heads_k: num_heads_k.
+
+    Return:
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
+        num_splits: (batch_size + 1), dtype torch.int32.
+    """
+    return torch.ops._C.get_flash_mla_metadata(cache_seqlens,
+                                               num_heads_per_head_k,
+                                               num_heads_k)
+
+
+def flash_mla_with_kvcache(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    block_table: torch.Tensor,
+    cache_seqlens: torch.Tensor,
+    head_dim_v: int,
+    tile_scheduler_metadata: torch.Tensor,
+    num_splits: torch.Tensor,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        q: (batch_size, seq_len_q, num_heads_q, head_dim).
+        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
+        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
+        cache_seqlens: (batch_size), torch.int32.
+        head_dim_v: Head_dim of v.
+        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
+        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
+        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
+        causal: bool. Whether to apply causal attention mask.
+
+    Return:
+        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
+        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1]**(-0.5)
+    out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache(
+        q,
+        k_cache,
+        None,
+        head_dim_v,
+        cache_seqlens,
+        block_table,
+        softmax_scale,
+        causal,
+        tile_scheduler_metadata,
+        num_splits,
+    )
+    return out, softmax_lse
vllm/attention/backends/flashmla.py (new file, 239 lines)
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata,
                                                MLACommonMetadataBuilder,
                                                MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA"

    @staticmethod
    def get_impl_cls() -> Type["FlashMLAImpl"]:
        return FlashMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashMLAState"]:
        return FlashMLAState


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
                                                   torch.Tensor]] = None
    decode_num_splits: Optional[torch.Tensor] = None

    @property
    def decode_metadata(self):
        decode_metadata = super().decode_metadata
        # TODO: cache assignment?
        if decode_metadata is not None:
            decode_metadata.decode_tile_scheduler_metadata=\
                self.decode_tile_scheduler_metadata
            decode_metadata.decode_num_splits=\
                self.decode_num_splits
        return decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        raise NotImplementedError(
            "advance_step is not implemented for FlashMLA")


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        m = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                          batch_size)

        if m.num_decode_tokens > 0:
            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
                get_mla_metadata(
                    m.seq_lens_tensor[m.num_prefills:],
                    self.num_q_heads,
                    1,  # MQA for the decode path
                )

        return m


class FlashMLAState(MLACommonState[FlashMLAMetadata]):

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Run a dummy `get_mla_metadata` so we can get the right shapes
        self._graph_decoder_tile_scheduler_metadata, \
            self._graph_decode_num_splits = get_mla_metadata(
            torch.ones(
                max_batch_size, dtype=torch.int32, device=self.runner.device),
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        with super().graph_capture(max_batch_size):
            yield

        del self._graph_decoder_tile_scheduler_metadata
        del self._graph_decode_num_splits

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)
        assert metadata.num_decode_tokens > 0

        decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
            self._graph_seq_lens[:batch_size],
            self.num_q_heads,
            1,  # MQA for the decode path
        )

        self._graph_decoder_tile_scheduler_metadata.copy_(
            decoder_tile_scheduler_metadata)
        self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)

        metadata.decode_tile_scheduler_metadata=\
            self._graph_decoder_tile_scheduler_metadata
        metadata.decode_num_splits=\
            self._graph_decode_num_splits[:batch_size + 1]

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        input_buffers["decode_tile_scheduler_metadata"] = \
            attn_metadata.decode_metadata.decode_tile_scheduler_metadata
        input_buffers["decode_num_splits"] = \
            attn_metadata.decode_metadata.decode_num_splits

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        input_buffers["decode_tile_scheduler_metadata"].copy_(
            attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
        input_buffers["decode_num_splits"].copy_(
            attn_metadata.decode_metadata.decode_num_splits)


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashMLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=decode_meta.block_tables,
            cache_seqlens=decode_meta.seq_lens_tensor,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
            num_splits=decode_meta.decode_num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj_and_o_proj(o)
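Shape bookkeeping for the decode path above, as an illustrative sketch assuming DeepSeek-style MLA sizes (kv_lora_rank 512, RoPE head dim 64, block size 64): the query gets a singleton sequence dimension and the latent KV cache a singleton head dimension before the kernel is called.

import torch

batch, heads = 2, 16
kv_lora_rank, rope_dim, block_size, num_blocks = 512, 64, 64, 8

q_nope = torch.zeros(batch, heads, kv_lora_rank)
q_pe = torch.zeros(batch, heads, rope_dim)
q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)        # (2, 1, 16, 576)

kv_c_and_k_pe_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + rope_dim)
k_cache = kv_c_and_k_pe_cache.unsqueeze(-2)               # (8, 64, 1, 576)

assert q.shape == (batch, 1, heads, kv_lora_rank + rope_dim)
assert k_cache.shape == (num_blocks, block_size, 1, kv_lora_rank + rope_dim)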
@@ -293,7 +293,10 @@ class MLACommonBackend(AttentionBackend):
         return [576]


-class MLACommonState(AttentionState):
+T = TypeVar("T", bound="MLACommonMetadata")
+
+
+class MLACommonState(AttentionState, Generic[T]):

     def __init__(self, runner):
         self.runner = runner
@@ -355,7 +358,9 @@ class MLACommonState(AttentionState):
         return self.__class__(self.runner)

     def graph_capture_get_metadata_for_batch(
-            self, batch_size: int, is_encoder_decoder_model: bool = False):
+            self,
+            batch_size: int,
+            is_encoder_decoder_model: bool = False) -> T:
         assert self._is_graph_capturing

         attn_metadata = self.runner.attn_backend.make_metadata(
@@ -507,8 +512,8 @@ class MLACommonMetadata(AttentionMetadata):
     # [4, 6], it is [0, 4, 10].
     seq_start_loc: Optional[torch.Tensor] = None

-    _cached_prefill_metadata: Optional["MLACommonMetadata"] = None
-    _cached_decode_metadata: Optional["MLACommonMetadata"] = None
+    _cached_prefill_metadata: Optional[Any] = None
+    _cached_decode_metadata: Optional[Any] = None

     num_prefill_tokens: int

@@ -537,7 +542,7 @@ class MLACommonMetadata(AttentionMetadata):
                 f" received {self.head_dim}.")

     @property
-    def prefill_metadata(self) -> Optional["MLACommonMetadata"]:
+    def prefill_metadata(self):
         if self.num_prefills == 0:
             return None

@@ -565,7 +570,7 @@ class MLACommonMetadata(AttentionMetadata):
         input_positions = (None if self.input_positions is None else
                            self.input_positions[:self.num_prefill_tokens])

-        self._cached_prefill_metadata = MLACommonMetadata(
+        self._cached_prefill_metadata = self.__class__(
             # Required by ModelRunner
             use_cuda_graph=False,  # Not Attention Related
             # Required by Attention Metadata
@@ -599,7 +604,7 @@ class MLACommonMetadata(AttentionMetadata):
         return self._cached_prefill_metadata

     @property
-    def decode_metadata(self) -> Optional["MLACommonMetadata"]:
+    def decode_metadata(self):
         if self.num_decode_tokens == 0:
             return None

@@ -617,7 +622,7 @@ class MLACommonMetadata(AttentionMetadata):
         input_positions = (None if self.input_positions is None else
                            self.input_positions[self.num_prefill_tokens:])

-        self._cached_decode_metadata = MLACommonMetadata(
+        self._cached_decode_metadata = self.__class__(
             # Required by ModelRunner
             use_cuda_graph=self.use_cuda_graph,  # Not Attention Related
             # Required by Attention Metadata
@@ -723,10 +728,7 @@ class MLACommonMetadata(AttentionMetadata):
             block_tables=self.block_tables)


-T = TypeVar("T", bound=MLACommonMetadata)
-
-
-class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]):
+class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
     """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
@@ -959,7 +961,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]):
         assert max(context_chunk_seq_tot) <= \
             self.chunked_prefill_workspace_size

-        return MLACommonMetadata(
+        return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
             use_cuda_graph=use_captured_graph,  # Not Attention Related
             # Required by Attention Metadata
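The Generic[T], self.__class__ and make_metadata changes above exist so that a subclass backend (FlashMLA here) can carry its own metadata type through the shared state, builder, and cached prefill/decode metadata; a toy sketch with stand-in classes, not the vLLM ones:

from dataclasses import dataclass
from typing import Generic, TypeVar

@dataclass
class CommonMetadata:
    num_decode_tokens: int = 0

T = TypeVar("T", bound=CommonMetadata)

class CommonBuilder(Generic[T]):
    metadata_cls: type

    def build(self) -> T:
        # Analogous to self.runner.attn_backend.make_metadata(...): construct
        # the subclass's metadata type, not the hard-coded base class.
        return self.metadata_cls()

@dataclass
class FlashMLAToyMetadata(CommonMetadata):
    decode_num_splits: int = 0

class FlashMLAToyBuilder(CommonBuilder[FlashMLAToyMetadata]):
    metadata_cls = FlashMLAToyMetadata

assert isinstance(FlashMLAToyBuilder().build(), FlashMLAToyMetadata)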
vllm/attention/ops/flashmla.py (new file, 115 lines)
# SPDX-License-Identifier: Apache-2.0
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not current_platform.is_cuda():
        return False, "FlashMLA is only supported on CUDA devices."
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available, likely was not "\
            "compiled due to insufficient nvcc version or a supported arch "\
            "(only sm90a currently) was not in the list of target arches to "\
            "compile for."
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
                                                  num_heads_per_head_k,
                                                  num_heads_k)


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
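A minimal decode-shaped usage sketch of this module (shapes mirror the unit test; it needs a Hopper GPU and a vLLM build that produced vllm._flashmla_C, and is skipped otherwise via is_flashmla_supported):

import torch
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

if is_flashmla_supported()[0]:
    torch.set_default_device("cuda")
    b, s_q, h_q, h_kv, d, dv, block_size = 4, 1, 16, 1, 576, 512, 64
    cache_seqlens = torch.full((b, ), 256, dtype=torch.int32)
    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16)
    block_table = torch.arange(b * 4, dtype=torch.int32).view(b, 4)  # 4 blocks per seq
    k_cache = torch.randn(block_table.numel(), block_size, h_kv, d,
                          dtype=torch.bfloat16)

    tile_md, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
    out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, dv,
                                      tile_md, num_splits, causal=True)
    # out: (b, s_q, h_q, dv), lse: (b, h_q, s_q)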
@@ -141,6 +141,14 @@ class CudaPlatformBase(Platform):
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
+        # TODO(lucas): handle this more gracefully
+        if envs.VLLM_ATTENTION_BACKEND is not None \
+            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
+            and cache_config.block_size != 64:
+            cache_config.block_size = 64
+            logger.info(
+                "FlashMLA: Forcing kv cache block size to 64 since this"
+                " is currently the only block size supported by the kernel.")

     @classmethod
     def get_current_memory_usage(cls,
@@ -157,6 +165,22 @@ class CudaPlatformBase(Platform):
             logger.info("Using Flash Attention backend on V1 engine.")
             return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
         if use_mla:
+            if selected_backend == _Backend.FLASHMLA:
+                from vllm.attention.backends.flashmla import (
+                    is_flashmla_supported)
+                if not is_flashmla_supported()[0]:
+                    logger.warning(
+                        "FlashMLA backend is not supported due to %s",
+                        is_flashmla_supported()[1])
+                elif block_size != 64:
+                    logger.warning(
+                        "FlashMLA backend is not supported for block size %d"
+                        " (currently only supports block size 64).",
+                        block_size)
+                else:
+                    logger.info("Using FlashMLA backend.")
+                    return "vllm.attention.backends.flashmla.FlashMLABackend"
+
             logger.info("Using Triton MLA backend.")
             return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if selected_backend == _Backend.FLASHINFER:
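With these platform hooks, the backend is opted into through the existing VLLM_ATTENTION_BACKEND environment variable; a sketch (the model name is only an example of an MLA-style checkpoint, and the engine construction is left commented out since it needs a Hopper GPU):

import os

# Must be set before the engine is constructed; the platform code above will
# also force the kv-cache block size to 64 for this backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"

# from vllm import LLM
# llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)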
@@ -35,6 +35,7 @@ class _Backend(enum.Enum):
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
     TRITON_MLA = enum.auto()
+    FLASHMLA = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()