
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
    linear_decode_forward_triton)
from vllm.platforms import current_platform

NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Reference implementation of the lightning attention core algorithm.

    Unlike the main implementation, this processes each step sequentially
    instead of using parallelized Triton kernels.
    """
    B, H, S, D = q.shape
    E = v.shape[-1]
    dtype = q.dtype
    output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

    # Use clone() to ensure an independent copy of the history
    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
    else:
        kv_cache = kv_history.clone()

    # Convert decay factors to matrix form
    if ed.dim() == 1:
        decay = torch.exp(-ed).view(1, -1, 1, 1)
    else:
        decay = torch.exp(-ed)

    for b in range(B):
        for step in range(S):
            # Process all heads at once for this position
            q_bs = q[b, :, step]  # [H, D]
            k_bs = k[b, :, step]  # [H, D]
            v_bs = v[b, :, step]  # [H, E]

            for h in range(H):
                # Calculate the KV outer product for this head
                kv_outer = torch.outer(k_bs[h], v_bs[h])

                # Update the KV cache with decay
                # Note: using the same order as the Triton kernel
                kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer

                # Calculate the attention output
                output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])

    # Match the shape returned by the actual implementation, which returns
    # a tensor of shape [B, H, 2, D, E] where dimension 2 contains both the
    # KV state and the KV history.
    kv_reshaped = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
                               dim=2)  # [B, H, 2, D, E]

    return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation of the linear attention decode function."""
    B, H, _, D = q.shape
    output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

    # Calculate decay factors once (more efficient)
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    # Process each batch
    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions
        if slot_id == -1:
            continue

        # Process all heads at once for this batch
        q_b = q[b, :, 0]  # [H, D]
        k_b = k[b, :, 0]  # [H, D]
        v_b = v[b, :, 0]  # [H, D]

        # Process each attention head
        for h in range(H):
            # Get current query, key and value
            q_bh = q_b[h]
            k_bh = k_b[h]
            v_bh = v_b[h]

            # Get cache
            kv_cache_old = kv_caches[b, h]

            # Calculate new key-value outer product
            kv_outer = torch.outer(k_bh, v_bh)

            # Apply decay and update cache
            kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old

            # Calculate output
            out_h = torch.matmul(q_bh, kv_new)

            # Update output and cache
            output[b, h * D:(h + 1) * D] = out_h
            kv_caches[b, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
|
@torch.inference_mode()
|
|
def test_linear_decode_forward_triton(
|
|
batch_size: int,
|
|
num_heads: int,
|
|
head_size: int,
|
|
dtype: torch.dtype,
|
|
):
|
|
torch.set_default_device("cuda")
|
|
torch.manual_seed(42)
|
|
torch.cuda.manual_seed_all(42)
|
|
current_platform.seed_everything(42)
|
|
base = 0.01
|
|
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
|
|
kv_caches = base * torch.randn(batch_size,
|
|
num_heads,
|
|
head_size,
|
|
head_size,
|
|
dtype=dtype,
|
|
device="cuda")
|
|
|
|
kv_caches_copy = kv_caches.clone()
|
|
|
|
slope_rate = torch.zeros(num_heads, device="cuda")
|
|
for h in range(num_heads):
|
|
slope_rate[h] = 0.1 * (h + 1)
|
|
|
|
slot_idx = torch.arange(batch_size, device="cuda")
|
|
|
|
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
|
slope_rate, slot_idx)
|
|
|
|
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
|
slope_rate, slot_idx)
|
|
torch.testing.assert_close(triton_output,
|
|
reference_output,
|
|
rtol=1e-1,
|
|
atol=1e-1)
|
|
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
|
|
|
assert triton_output.shape == (batch_size, num_heads * head_size)
|
|
|
|
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
|
@torch.inference_mode()
|
|
def test_linear_decode_forward_triton_with_padding(
|
|
num_heads: int,
|
|
head_size: int,
|
|
dtype: torch.dtype,
|
|
):
|
|
torch.set_default_device("cuda")
|
|
torch.manual_seed(42)
|
|
torch.cuda.manual_seed_all(42)
|
|
current_platform.seed_everything(42)
|
|
|
|
batch_size = 4
|
|
base = 0.01
|
|
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
|
|
|
kv_caches = base * torch.randn(batch_size,
|
|
num_heads,
|
|
head_size,
|
|
head_size,
|
|
dtype=dtype,
|
|
device="cuda")
|
|
|
|
kv_caches_copy = kv_caches.clone()
|
|
|
|
slope_rate = torch.zeros(num_heads, device="cuda")
|
|
for h in range(num_heads):
|
|
slope_rate[h] = 0.1 * (h + 1)
|
|
|
|
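    # The -1 entry marks a padded (inactive) slot: the kernel is expected to
    # skip it and leave its KV cache entry untouched.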
    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)
    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    padding_mask = (slot_idx != -1).unsqueeze(1).expand(
        -1, num_heads * head_size)
    triton_masked = triton_output[padding_mask]
    reference_masked = reference_output[padding_mask]

    atol, rtol = 1.5e-1, 1.5e-1

    valid_indices = slot_idx != -1
    for i in range(batch_size):
        if valid_indices[i]:
            torch.testing.assert_close(kv_caches[i],
                                       kv_caches_copy[i],
                                       rtol=rtol,
                                       atol=atol)

    torch.testing.assert_close(triton_masked,
                               reference_masked,
                               rtol=rtol,
                               atol=atol)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
|
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
|
|
@pytest.mark.parametrize("dtype", DTYPES)
|
|
@torch.inference_mode()
|
|
def test_lightning_attention_reference(
|
|
batch_size: int,
|
|
num_heads: int,
|
|
head_size: int,
|
|
seq_len: int,
|
|
dtype: torch.dtype,
|
|
):
|
|
torch.set_default_device("cuda")
|
|
torch.manual_seed(42)
|
|
torch.cuda.manual_seed_all(42)
|
|
current_platform.seed_everything(42)
|
|
|
|
base = 0.01
|
|
q = base * torch.randn(
|
|
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
|
k = base * torch.randn(
|
|
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
|
v = base * torch.randn(
|
|
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
|
|
|
ed = torch.zeros(num_heads, device="cuda")
|
|
for h in range(num_heads):
|
|
ed[h] = 0.1 * (h + 1)
|
|
|
|
kv_history = base * torch.randn(batch_size,
|
|
num_heads,
|
|
head_size,
|
|
head_size,
|
|
dtype=dtype,
|
|
device="cuda")
|
|
|
|
kv_history_clone = kv_history.clone()
|
|
|
|
ref_output, ref_kv_cache = reference_lightning_attention(
|
|
q, k, v, ed, 256, kv_history)
|
|
|
|
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
|
actual_output, actual_kv_cache = lightning_attention(
|
|
q, k, v, ed, 256, kv_history_clone)
|
|
|
|
atol, rtol = 1.5e-1, 1.5e-1
|
|
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
|
torch.testing.assert_close(ref_kv_cache,
|
|
actual_kv_cache,
|
|
rtol=rtol,
|
|
atol=atol)
|
|
|
|
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
|
assert ref_kv_cache.shape == actual_kv_cache.shape
|