From e9528f6dc614879952aa871fae8df296adcfc559 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Fri, 11 Apr 2025 20:50:50 +0800
Subject: [PATCH] [Kernel] support merge_attn_states CUDA kernel, 3x speedup
 (#16173)

Signed-off-by: DefTruth
---
 CMakeLists.txt                           |   1 +
 csrc/attention/merge_attn_states.cu      | 173 +++++++++++++++
 csrc/ops.h                               |   9 +
 csrc/torch_bindings.cpp                  |  15 ++
 tests/kernels/test_merge_attn_states.py  | 265 +++++++++++++++++++++++
 vllm/_custom_ops.py                      |  11 +
 vllm/attention/backends/mla/common.py    |   3 +-
 vllm/attention/ops/merge_attn_states.py  |  42 ++++
 vllm/v1/attention/backends/flash_attn.py |   2 +-
 vllm/v1/attention/backends/mla/common.py |   2 +-
 10 files changed, 519 insertions(+), 4 deletions(-)
 create mode 100644 csrc/attention/merge_attn_states.cu
 create mode 100644 tests/kernels/test_merge_attn_states.py
 create mode 100644 vllm/attention/ops/merge_attn_states.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 15db4a4f..a0c25df6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
   "csrc/cache_kernels.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
+  "csrc/attention/merge_attn_states.cu"
   "csrc/pos_encoding_kernels.cu"
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu
new file mode 100644
index 00000000..7af0cace
--- /dev/null
+++ b/csrc/attention/merge_attn_states.cu
@@ -0,0 +1,173 @@
+#include <optional>
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <algorithm>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+
+namespace vllm {
+
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+// can be used to combine partial attention results (in the split-KV case)
+template <typename scalar_t, const uint NUM_THREADS>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output, float* output_lse, const scalar_t* prefix_output,
+    const float* prefix_lse, const scalar_t* suffix_output,
+    const float* suffix_lse, const uint num_tokens, const uint num_heads,
+    const uint head_size) {
+  using pack_128b_t = uint4;
+  const uint pack_size = 16 / sizeof(scalar_t);
+  const uint threads_per_head = head_size / pack_size;
+
+  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const uint token_head_threads = num_tokens * num_heads * threads_per_head;
+
+  if (global_idx >= token_head_threads) return;
+
+  // global_idx -> token_idx + head_idx + pack_idx
+  const uint token_head_idx = global_idx / threads_per_head;
+  const uint pack_idx = global_idx % threads_per_head;
+
+  const uint token_idx = token_head_idx / num_heads;
+  const uint head_idx = token_head_idx % num_heads;
+
+  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
+  const uint head_offset =
+      token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
+  scalar_t* output_head_ptr = output + head_offset;
+
+  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  p_lse =
+      std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse =
+      std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
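+  // Any inf lse (which can appear for an empty or invalid partial result)
+  // has been replaced with -inf above, so exp() assigns that branch zero
+  // weight in the merge below.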
+
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
+
+  if (pack_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+        prefix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
+        suffix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t o_out_pack;
+
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep high precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f =
+          vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f =
+          vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+    }
+
+    // Pack 128b storage
+    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
+        o_out_pack;
+  }
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && pack_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[head_idx * num_tokens + token_idx] = out_lse;
+  }
+}
+
+}  // namespace vllm
+
+// The following macro is used to dispatch the merge function based on
+// the output data type. FN is a macro that calls a templated function.
+#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
+  {                                                                     \
+    if (scalar_dtype == at::ScalarType::Float) {                        \
+      fn(float);                                                        \
+    } else if (scalar_dtype == at::ScalarType::Half) {                  \
+      fn(uint16_t);                                                     \
+    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
+      fn(__nv_bfloat16);                                                \
+    } else {                                                            \
+      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
+    }                                                                   \
+  }
+
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                 \
+  {                                                                     \
+    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>               \
+        <<<grid, block>>>(                                              \
+            reinterpret_cast<scalar_t*>(output.data_ptr()),             \
+            output_lse_ptr,                                             \
+            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),      \
+            reinterpret_cast<float*>(prefix_lse.data_ptr()),            \
+            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),      \
+            reinterpret_cast<float*>(suffix_lse.data_ptr()),            \
+            num_tokens, num_heads, head_size);                          \
+  }
+
+/* @brief Merges the attention states from prefix and suffix
+ * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
+ *
+ * @param output [n,h,d] The output tensor to store the merged attention
+ * states.
+ * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
+ * @param prefix_output [n,h,d] The prefix attention states.
+ * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
+ * states.
+ * @param suffix_output [n,h,d] The suffix attention states.
+ * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
+ * states.
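+ *
+ * The merge is the standard log-sum-exp combination of two partial
+ * softmax-attention results over disjoint KV ranges, as implemented by
+ * the kernel above: with
+ *   max_lse = max(prefix_lse, suffix_lse),
+ *   se      = exp(prefix_lse - max_lse) + exp(suffix_lse - max_lse),
+ * it computes
+ *   output     = prefix_output * exp(prefix_lse - max_lse) / se
+ *              + suffix_output * exp(suffix_lse - max_lse) / se,
+ *   output_lse = log(se) + max_lse.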
+ */
+template <typename scalar_t>
+void merge_attn_states_launcher(torch::Tensor& output,
+                                std::optional<torch::Tensor> output_lse,
+                                const torch::Tensor& prefix_output,
+                                const torch::Tensor& prefix_lse,
+                                const torch::Tensor& suffix_output,
+                                const torch::Tensor& suffix_lse) {
+  constexpr uint NUM_THREADS = 128;
+  const uint num_tokens = output.size(0);
+  const uint num_heads = output.size(1);
+  const uint head_size = output.size(2);
+  const uint pack_size = 16 / sizeof(scalar_t);
+  TORCH_CHECK(head_size % pack_size == 0,
+              "headsize must be multiple of pack_size:", pack_size);
+  float* output_lse_ptr = nullptr;
+  if (output_lse.has_value()) {
+    output_lse_ptr = output_lse.value().data_ptr<float>();
+  }
+  // Process one pack of elements per thread. float -> 4, half/bf16 -> 8
+  const uint threads_per_head = head_size / pack_size;
+  const uint total_threads = num_tokens * num_heads * threads_per_head;
+
+  dim3 block(NUM_THREADS);
+  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
+
+  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
+}
+
+#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t)                    \
+  {                                                                  \
+    merge_attn_states_launcher<scalar_t>(output, output_lse,         \
+                                         prefix_output, prefix_lse,  \
+                                         suffix_output, suffix_lse); \
+  }
+
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse) {
+  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 152c94e8..86039a26 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -52,6 +52,15 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 
+#ifndef USE_ROCM
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse);
+#endif
+
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               double epsilon);
 
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index d3b80572..b6ff6a00 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 
+#ifndef USE_ROCM
+  // Merge attn states
+  // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+  // can be used to combine partial attention results (in the split-KV case)
+  ops.def(
+      "merge_attn_states("
+      "    Tensor! output,"
+      "    Tensor!? output_lse,"
+      "    Tensor prefix_output,"
+      "    Tensor prefix_lse,"
+      "    Tensor suffix_output,"
+      "    Tensor suffix_lse) -> ()");
+  ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
+#endif
+
   // Activation ops
   // Activation function used in SwiGLU.
   ops.def("silu_and_mul(Tensor! 
out, Tensor input) -> ()"); diff --git a/tests/kernels/test_merge_attn_states.py b/tests/kernels/test_merge_attn_states.py new file mode 100644 index 00000000..7038fbea --- /dev/null +++ b/tests/kernels/test_merge_attn_states.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import pytest +import torch + +from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda +from vllm.attention.ops.triton_merge_attn_states import ( + merge_attn_states as merge_attn_states_triton) +from vllm.platforms import current_platform + + +# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states_torch( + output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] +): + p_lse = prefix_lse + s_lse = suffix_lse + # inf -> -inf + p_lse[p_lse == torch.inf] = -torch.inf + s_lse[s_lse == torch.inf] = -torch.inf + # max_lse [NUM_HEADS, NUM_TOKENS] + max_lse = torch.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + p_lse_exp = torch.exp(p_lse) + s_lse_exp = torch.exp(s_lse) + out_se = (p_lse_exp + s_lse_exp) + if output_lse is not None: + output_lse = torch.log(out_se) + max_lse + p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] + s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] + p_scale = torch.transpose(p_scale, 0, + 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = torch.transpose(s_scale, 0, + 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + output = prefix_output * p_scale + suffix_output * s_scale + return output, output_lse + + +NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536, 4096] +NUM_QUERY_HEADS = [4, 8, 16, 32, 48, 64] +HEAD_SIZES = [32, 48, 64, 96, 128, 256] +DTYPES = [torch.float32, torch.half, torch.bfloat16] + +all_case_info: list[tuple] = [] + + +def generate_markdown_table(): + global all_case_info + table_header = ("| tokens | heads | headsize | dtype " + "| device | torch | triton | cuda | speedup |") + table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |" + + def shortly_dtype(dtype: torch.dtype) -> str: + return str(dtype).removeprefix("torch.") + + def shortly_device(device: str) -> str: + return device.removeprefix("NVIDIA").strip() + + print(table_header) + print(table_separator) + for info in all_case_info: + (num_tokens, num_heads, head_size, dtype, device, + avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, + performance_improved) = info + dtype = shortly_dtype(dtype) + device = shortly_device(device) + print(f"| {num_tokens} | {num_heads} | {head_size} " + f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " + f"| {avg_time_triton_kernel:.5f}ms " + f"| {avg_time_cuda_kernel:.5f}ms " + f"| {performance_improved:.4f}x |") + + +@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) +@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("output_dtype", DTYPES) +@torch.inference_mode() +def test_merge_attn_states(num_tokens: int, num_query_heads: int, + head_size: int, output_dtype: torch.dtype): + if not current_platform.is_cuda(): + 
pytest.skip('Currently only support compare triton merge_attn_states ' + 'with custom cuda merge_attn_states kernel') + + NUM_TOKENS = num_tokens + NUM_HEADS = num_query_heads + HEAD_SIZE = head_size + + print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " + f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"Device: {current_platform.get_device_name()}") + + # prefix_lse and suffix_lse contain inf and normal values + prefix_lse = torch.randn(NUM_HEADS, + NUM_TOKENS, + dtype=torch.float32, + device="cuda") + suffix_lse = torch.randn(NUM_HEADS, + NUM_TOKENS, + dtype=torch.float32, + device="cuda") + + # Generate boolean masks + mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 + mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 + # Ensure that the same position is not True at the same time + combined_mask = torch.logical_and(mask_prefix, mask_suffix) + mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) + mask_suffix = torch.logical_and(mask_suffix, ~combined_mask) + + prefix_lse[mask_prefix] = float('inf') + suffix_lse[mask_suffix] = float('inf') + + # Other input tensors (need to be initialized but + # no actual calculation needed) + output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), + dtype=output_dtype, + device="cuda") + output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS), + dtype=torch.float32, + device="cuda") + prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), + dtype=output_dtype, + device="cuda") + suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), + dtype=output_dtype, + device="cuda") + + warmup_times = 2 + repeat_times = 20 + + output_torch = output.clone() + output_lse_torch = output_lse.clone() + total_time_torch_kernel = 0 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + # 0. Run the Torch kernel + prefix_lse_torch = prefix_lse.clone() + suffix_lse_torch = suffix_lse.clone() + for _ in range(warmup_times): + output_torch, output_lse_torch = merge_attn_states_torch( + output_torch, prefix_output, prefix_lse_torch, suffix_output, + suffix_lse_torch, output_lse_torch) + torch.cuda.synchronize() + + for _ in range(repeat_times): + start.record() + output_torch, output_lse_torch = merge_attn_states_torch( + output_torch, prefix_output, prefix_lse_torch, suffix_output, + suffix_lse_torch, output_lse_torch) + end.record() + torch.cuda.synchronize() + total_time_torch_kernel += start.elapsed_time(end) + + avg_time_torch_kernel = total_time_torch_kernel / repeat_times + + # 1. Run the Triton kernel + output_ref_triton = output.clone() + output_lse_ref_triton = output_lse.clone() + + total_time_triton_kernel = 0 + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + for _ in range(warmup_times): + merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, + suffix_output, suffix_lse, + output_lse_ref_triton) + torch.cuda.synchronize() + + for _ in range(repeat_times): + start.record() + merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, + suffix_output, suffix_lse, + output_lse_ref_triton) + end.record() + torch.cuda.synchronize() + total_time_triton_kernel += start.elapsed_time(end) + + avg_time_triton_kernel = total_time_triton_kernel / repeat_times + + # 2. 
Run the CUDA kernel + total_time_cuda_kernel = 0 + output_cuda = output.clone() + output_lse_cuda = output_lse.clone() + + for _ in range(warmup_times): + merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse_cuda) + torch.cuda.synchronize() + + for _ in range(repeat_times): + start.record() + merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse_cuda) + end.record() + torch.cuda.synchronize() + total_time_cuda_kernel += start.elapsed_time(end) + + avg_time_cuda_kernel = total_time_cuda_kernel / repeat_times + + # 3. Performance compare + performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel + print(f" Torch time: {avg_time_torch_kernel:.6f}ms") + print(f"Triton time: {avg_time_triton_kernel:.6f}ms") + print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " + f"Performance: {performance_improved:.5f}x") + print("-" * 100) + + # 4. Correctness compare + # Liger Kernel: Efficient Triton Kernels for LLM Training + # https://arxiv.org/pdf/2410.10989, 3.3 Correctness + # use rtol = 1e-2 for bfloat16. + rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3 + + def diff(a: torch.Tensor, b: torch.Tensor): + max_diff = torch.max(torch.abs(a.float() - b.float())) + return max_diff + + # Use Triton output as reference because we want to replace + # the Triton kernel with custom CUDA kernel for merge attn + # states operation. + output_ref = output_ref_triton + output_lse_ref = output_lse_ref_triton + torch.testing.assert_close(output_cuda.float(), + output_ref.float(), + atol=1e-3, + rtol=rtol) + print("Output all match, max abs diff:") + print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") + print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}") + print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}") + print("-" * 100) + + torch.testing.assert_close(output_lse_cuda.float(), + output_lse_ref.float(), + atol=1e-3, + rtol=rtol) + print("Output LSE all match, max abs diff:") + print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") + print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}") + print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}") + print("-" * 100) + + print("All output values test passed! 
All inf values " + "are correctly replaced with -inf.") + print("-" * 100) + + device = current_platform.get_device_name() + all_case_info.append( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device, + avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, + performance_improved)) + if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * + len(NUM_QUERY_HEADS) * len(DTYPES)): + generate_markdown_table() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 719e02ec..7a4c93ad 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -138,6 +138,17 @@ def mla_decode_kvcache_cpu( block_tables, seq_lens) +# merge attn states ops +def merge_attn_states(output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None) -> None: + torch.ops._C.merge_attn_states(output, output_lse, prefix_output, + prefix_lse, suffix_output, suffix_lse) + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 5a47c0f6..54278f5f 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -204,6 +204,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) @@ -217,9 +218,7 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version if HAS_TRITON: from vllm.attention.ops.triton_flash_attention import triton_attention - from vllm.attention.ops.triton_merge_attn_states import merge_attn_states else: - merge_attn_states = None triton_attention = None try: diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py new file mode 100644 index 00000000..f9fcfe6a --- /dev/null +++ b/vllm/attention/ops/merge_attn_states.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm.platforms import current_platform + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel + # is not support for FP8 dtype, fallback to use Triton kernel. + def supported_dtypes(o: torch.Tensor) -> bool: + return o.dtype in [torch.float32, torch.half, torch.bfloat16] + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA + # kernel load/store 128b(16 bytes) per memory issue within + # thread. Namely, the headsize(headdim) must be multiple of + # pack_size (float32 -> 4, half/bfloat16 -> 8). 
+ def supported_headdim(o: torch.Tensor) -> bool: + headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + if o.dtype == torch.float32: + return headdim % 4 == 0 + return headdim % 8 == 0 + + if (current_platform.is_cuda() and supported_dtypes(output) + and supported_headdim(output)): + from vllm._custom_ops import merge_attn_states + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) + else: + from vllm.attention.ops.triton_merge_attn_states import ( + merge_attn_states) + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e1858149..b4c7708d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -10,7 +10,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) -from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e6c4ebc7..8c7179ba 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,7 +195,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) -from vllm.attention.ops.triton_merge_attn_states import merge_attn_states +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear,
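
A minimal usage sketch of the Python wrapper added above in vllm/attention/ops/merge_attn_states.py, not a definitive reference: it assumes a CUDA device and follows the layout used by the kernel and tests (outputs are [NUM_TOKENS, NUM_HEADS, HEAD_SIZE], LSE tensors are [NUM_HEADS, NUM_TOKENS]); the concrete shapes and dtype below are illustrative only.

import torch

from vllm.attention.ops.merge_attn_states import merge_attn_states

num_tokens, num_heads, head_size = 256, 16, 128

# Partial attention outputs and their log-sum-exp values, e.g. from the
# prefix and suffix KV ranges of a split-KV attention call.
prefix_output = torch.randn(num_tokens, num_heads, head_size,
                            dtype=torch.bfloat16, device="cuda")
suffix_output = torch.randn(num_tokens, num_heads, head_size,
                            dtype=torch.bfloat16, device="cuda")
prefix_lse = torch.randn(num_heads, num_tokens,
                         dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens,
                         dtype=torch.float32, device="cuda")

# Pre-allocated destinations; the op writes into them in place.
output = torch.empty_like(prefix_output)
output_lse = torch.empty(num_heads, num_tokens,
                         dtype=torch.float32, device="cuda")

# Dispatches to the new CUDA kernel when the dtype is fp32/fp16/bf16 and the
# head size is a multiple of the pack size; otherwise it falls back to the
# existing Triton kernel.
merge_attn_states(output, prefix_output, prefix_lse,
                  suffix_output, suffix_lse, output_lse)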