diff --git a/CMakeLists.txt b/CMakeLists.txt index 02a60c0e..0dd350c9 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) endif() -# vllm-flash-attn currently only supported on CUDA -if (NOT VLLM_GPU_LANG STREQUAL "CUDA") - return() +# For CUDA we also build and ship some external projects. +if (VLLM_GPU_LANG STREQUAL "CUDA") + include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/vllm_flash_attn.cmake) endif () - -# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target -# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the -# arches in the CUDA case (and instead set the gencodes on a per file basis) -# we need to manually set VLLM_GPU_ARCHES here. -if(VLLM_GPU_LANG STREQUAL "CUDA") - foreach(_ARCH ${CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") - endforeach() -endif() - -# -# Build vLLM flash attention from source -# -# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. -# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. -# They should be identical but if they aren't, this is a massive footgun. -# -# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). -# If no component is specified, vllm-flash-attn is still installed. - -# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. -# This is to enable local development of vllm-flash-attn within vLLM. -# It can be set as an environment variable or passed as a cmake argument. -# The environment variable takes precedence. 
-if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) - set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) -endif() - -if(VLLM_FLASH_ATTN_SRC_DIR) - FetchContent_Declare( - vllm-flash-attn SOURCE_DIR - ${VLLM_FLASH_ATTN_SRC_DIR} - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -else() - FetchContent_Declare( - vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade - GIT_PROGRESS TRUE - # Don't share the vllm-flash-attn build between build types - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -endif() - - -# Fetch the vllm-flash-attn library -FetchContent_MakeAvailable(vllm-flash-attn) -message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") - -# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in -# case only one is built, in the case both are built redundant work is done) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa2_C - FILES_MATCHING PATTERN "*.py" -) - -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa3_C - FILES_MATCHING PATTERN "*.py" -) - -# Nothing after vllm-flash-attn, see comment about macros above diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake new file mode 100644 index 00000000..62914751 --- /dev/null +++ b/cmake/external_projects/flashmla.cmake @@ -0,0 +1,66 @@ +include(FetchContent) + +# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory +# instead of downloading. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{FLASH_MLA_SRC_DIR}) + set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR}) +endif() + +if(FLASH_MLA_SRC_DIR) + FetchContent_Declare( + flashmla + SOURCE_DIR ${FLASH_MLA_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + flashmla + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git + GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + + +FetchContent_MakeAvailable(flashmla) +message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") + +# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. 
+# Only build FlashMLA kernels if we are building for something compatible with +# sm90a +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + set(FlashMLA_SOURCES + ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) + + set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/include) + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + define_gpu_extension_target( + _flashmla_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_SOURCES} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} + USE_SABI 3 + WITH_SOABI) +else() + # Create an empty target for setup.py when not targeting sm90a systems + add_custom_target(_flashmla_C) +endif() + diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake new file mode 100644 index 00000000..ef6261fa --- /dev/null +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -0,0 +1,67 @@ +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# If no component is specified, vllm-flash-attn is still installed. + +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. 
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() + +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade + GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +endif() + + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in +# case only one is built, in the case both are built redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) + +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" +) \ No newline at end of file diff --git a/setup.py b/setup.py index d8a336c2..a636d266 100755 --- a/setup.py +++ b/setup.py @@ -328,6 +328,7 @@ class repackage_wheel(build_ext): files_to_copy = [ "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", + "vllm/_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", @@ -612,6 +613,11 @@ if _is_cuda(): # FA3 requires CUDA 12.0 or later ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): + # Optional since this doesn't get built (produce an .so file) when + # not targeting a hopper system + ext_modules.append( + CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py new file mode 100644 index 00000000..21c1079f --- /dev/null +++ b/tests/kernels/test_flashmla.py @@ -0,0 +1,132 @@ +# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py +# SPDX-License-Identifier: Apache-2.0 +import math +import random + +import pytest +import torch +import triton + +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + + +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) + assert cos_diff < 1e-5 + +FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ + if not is_flashmla_supported()[0] else "FlashMLA is supported" + + +@pytest.mark.skipif(not is_flashmla_supported()[0], + reason=FLASH_MLA_UNSUPPORTED_REASON) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1, 2]) +@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@torch.inference_mode() +def 
test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, + varlen): + # TODO: parametrize using pytest + dtype = torch.bfloat16 + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}") + + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, + d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + ) + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ref_O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = ref_O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_flash, lse_flash = flash_mla() + out_torch, lse_torch = ref_mla() + cal_diff(out_flash, out_torch, "out") + cal_diff(lse_flash, lse_torch, "lse") + + t = triton.testing.do_bench(flash_mla, fast_flush=False) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " + f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3306610a..0e83bcae 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1163,3 +1163,67 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +def get_flash_mla_metadata( 
+ cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 00000000..273c69b6 --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: 
cache assignment? + if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + 
attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 FlashMLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 225fee8d..1befcb6b 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -293,7 +293,10 @@ class MLACommonBackend(AttentionBackend): return [576] -class MLACommonState(AttentionState): +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): def __init__(self, runner): self.runner = runner @@ -355,7 +358,9 @@ class MLACommonState(AttentionState): return self.__class__(self.runner) def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( @@ -507,8 +512,8 @@ class MLACommonMetadata(AttentionMetadata): # [4, 6], it is [0, 4, 10]. 
seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["MLACommonMetadata"] = None - _cached_decode_metadata: Optional["MLACommonMetadata"] = None + _cached_prefill_metadata: Optional[Any] = None + _cached_decode_metadata: Optional[Any] = None num_prefill_tokens: int @@ -537,7 +542,7 @@ class MLACommonMetadata(AttentionMetadata): f" received {self.head_dim}.") @property - def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + def prefill_metadata(self): if self.num_prefills == 0: return None @@ -565,7 +570,7 @@ class MLACommonMetadata(AttentionMetadata): input_positions = (None if self.input_positions is None else self.input_positions[:self.num_prefill_tokens]) - self._cached_prefill_metadata = MLACommonMetadata( + self._cached_prefill_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=False, # Not Attention Related # Required by Attention Metadata @@ -599,7 +604,7 @@ class MLACommonMetadata(AttentionMetadata): return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["MLACommonMetadata"]: + def decode_metadata(self): if self.num_decode_tokens == 0: return None @@ -617,7 +622,7 @@ class MLACommonMetadata(AttentionMetadata): input_positions = (None if self.input_positions is None else self.input_positions[self.num_prefill_tokens:]) - self._cached_decode_metadata = MLACommonMetadata( + self._cached_decode_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=self.use_cuda_graph, # Not Attention Related # Required by Attention Metadata @@ -723,10 +728,7 @@ class MLACommonMetadata(AttentionMetadata): block_tables=self.block_tables) -T = TypeVar("T", bound=MLACommonMetadata) - - -class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -959,7 +961,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): assert max(context_chunk_seq_tot) <= \ self.chunked_prefill_workspace_size - return MLACommonMetadata( + return self.runner.attn_backend.make_metadata( # Required by ModelRunner use_cuda_graph=use_captured_graph, # Not Attention Related # Required by Attention Metadata diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py new file mode 100644 index 00000000..18b69a6b --- /dev/null +++ b/vllm/attention/ops/flashmla.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + try: + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True + except ImportError: + _flashmla_C_AVAILABLE = False +else: + _flashmla_C_AVAILABLE = False + + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + if not current_platform.is_cuda(): + return False, "FlashMLA is only supported on CUDA devices." + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA is only supported on Hopper devices." 
+ if not _flashmla_C_AVAILABLE: + return False, "vllm._flashmla_C is not available, likely was not "\ + "compiled due to insufficient nvcc version or a supported arch "\ + "(only sm90a currently) was not in the list of target arches to "\ + "compile for." + return True, None + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... 
+# diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bf425b89..c6f3ccf0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -141,6 +141,14 @@ class CudaPlatformBase(Platform): cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 16 + # TODO(lucas): handle this more gracefully + if envs.VLLM_ATTENTION_BACKEND is not None \ + and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \ + and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "FlashMLA: Forcing kv cache block size to 64 since this" + " is currently the only block size supported by the kernel.") @classmethod def get_current_memory_usage(cls, @@ -157,6 +165,22 @@ class CudaPlatformBase(Platform): logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: + if selected_backend == _Backend.FLASHMLA: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + if not is_flashmla_supported()[0]: + logger.warning( + "FlashMLA backend is not supported due to %s", + is_flashmla_supported()[1]) + elif block_size != 64: + logger.warning( + "FlashMLA backend is not supported for block size %d" + " (currently only supports block size 64).", + block_size) + else: + logger.info("Using FlashMLA backend.") + return "vllm.attention.backends.flashmla.FlashMLABackend" + logger.info("Using Triton MLA backend.") return "vllm.attention.backends.triton_mla.TritonMLABackend" if selected_backend == _Backend.FLASHINFER: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d6dae2e5..0e4988a4 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -35,6 +35,7 @@ class _Backend(enum.Enum): OPENVINO = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() + FLASHMLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto()
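A minimal usage sketch of the new op wrappers in vllm/attention/ops/flashmla.py, condensed from tests/kernels/test_flashmla.py. It only runs where the _flashmla_C extension was built (sm90a, CUDA >= 12.3); the shapes, dtype and sequence lengths below are illustrative values borrowed from the test, not requirements of the wrapper itself.

import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)

supported, reason = is_flashmla_supported()
assert supported, reason

# Illustrative decode-shaped inputs: 4 sequences, 1 query token each,
# 16 query heads, MQA KV cache (1 KV head), 576/512 head dims, block size 64.
b, s_q, h_q, h_kv, d, dv, block_size = 4, 1, 16, 1, 576, 512, 64
device, dtype = torch.device("cuda"), torch.bfloat16

cache_seqlens = torch.full((b,), 256, dtype=torch.int32, device=device)
blocks_per_seq = 256 // block_size
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
k_cache = torch.randn(b * blocks_per_seq, block_size, h_kv, d,
                      dtype=dtype, device=device)
block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                           device=device).view(b, blocks_per_seq)

# Tile-scheduler metadata is computed once per batch of sequence lengths
# and reused for every layer's decode call.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv)

out, lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True)
# out: (b, s_q, h_q, dv), lse: (b, h_q, s_q), per the wrapper docstring.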
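The num_heads_per_head_k argument above follows the docstring identity num_heads_per_head_k = seq_len_q * num_heads_q // num_heads_k. On the decode path both FlashMLAMetadataBuilder.build() and FlashMLAState pass num_heads_k = 1 (MQA over the compressed KV cache) with one query token per sequence, so the argument reduces to the query head count. The head count below is an assumed example; only the identity comes from the patch.

# Per the get_mla_metadata docstring: seq_len_q * num_heads_q // num_heads_k.
seq_len_q = 1        # decode: a single new token per sequence
num_heads_q = 128    # illustrative query head count (model dependent)
num_heads_k = 1      # MQA over the compressed KV cache
assert seq_len_q * num_heads_q // num_heads_k == num_heads_q
# Matches FlashMLAMetadataBuilder.build(), which calls
# get_mla_metadata(seq_lens_tensor[num_prefills:], self.num_q_heads, 1).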
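At the engine level the backend is opt-in: the changes to vllm/platforms/cuda.py only return FlashMLABackend when the FLASHMLA backend is requested, the device supports it, and the KV-cache block size is 64 (check_and_update_config now forces 64 when VLLM_ATTENTION_BACKEND=FLASHMLA is set). A hedged end-to-end sketch follows; the model name is only an example of an MLA-based checkpoint and is not part of this change.

import os

# Request the backend before the engine is constructed; with this set,
# check_and_update_config forces cache_config.block_size to 64.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"

from vllm import LLM, SamplingParams

# Assumed example model: any MLA-based (DeepSeek-style) model works; non-MLA
# models keep using their usual attention backends.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
outputs = llm.generate(["FlashMLA says hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)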