2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-01-25 15:07:35 +08:00
|
|
|
"""
|
|
|
|
Test:
|
|
|
|
|
|
|
|
* Tests for MultiHeadAttention layer
|
|
|
|
"""
|
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from vllm.attention.layer import MultiHeadAttention
|
|
|
|
from vllm.attention.selector import _Backend, _cached_get_attn_backend
|
|
|
|
from vllm.platforms import current_platform
|
|
|
|
from vllm.platforms.cpu import CpuPlatform
|
|
|
|
from vllm.platforms.cuda import CudaPlatform
|
|
|
|
from vllm.platforms.rocm import RocmPlatform
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend lru cache around every test.

    The cache is cleared before each test so the test runs without cached
    backend selections. It is cleared again on teardown so that selections
    made while ``current_platform`` was patched do not stay cached for tests
    that run later in the same session.
    """
    _cached_get_attn_backend.cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test the attention backend selected on different platforms and devices.

    The selector's ``current_platform`` is patched to a CPU, ROCm or CUDA
    platform object, and the backend MultiHeadAttention picks for each listed
    head size is checked against the expected backend.
    """
    # device -> (platform class, [(head_size, expected backend), ...]).
    # Classes are stored (not instances) so only the platform under test is
    # constructed, and a fresh instance is used for every check.
    cases = {
        "cpu": (CpuPlatform, [(64, _Backend.TORCH_SDPA)]),
        "hip": (RocmPlatform, [(64, _Backend.TORCH_SDPA)]),
        "cuda": (CudaPlatform, [(64, _Backend.XFORMERS),
                                (72, _Backend.XFORMERS)]),
    }
    platform_cls, expectations = cases[device]

    # The test pins float16 as the default dtype; restore the previous value
    # afterwards so this global setting does not leak into other tests.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        for head_size, expected_backend in expectations:
            with patch("vllm.attention.selector.current_platform",
                       platform_cls()):
                attn = MultiHeadAttention(16, head_size, scale=1)
                assert attn.attn_backend == expected_backend
    finally:
        torch.set_default_dtype(prev_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
def ref_attention(
|
|
|
|
query: torch.Tensor,
|
|
|
|
key: torch.Tensor,
|
|
|
|
value: torch.Tensor,
|
|
|
|
scale: float,
|
|
|
|
) -> torch.Tensor:
|
|
|
|
"""
|
|
|
|
Native implementation of scaled dot product attention without mask:
|
|
|
|
- query, key, value: [batch_size, seq_len, num_heads, head_size]
|
|
|
|
- attn_mask: [batch_size, seq_len, seq_len]
|
|
|
|
"""
|
|
|
|
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
|
|
|
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
|
|
|
|
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
|
|
|
out = torch.matmul(attn_weights, value).transpose(1, 2)
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# float32 is only exercised off ROCm. The xformers flash-attention ops
# (flshattF / tritonflashattF) support only float16 and bfloat16, which is
# presumably why ROCm skips float32 -- confirm against the ROCm backend.
if current_platform.is_rocm():
    DTYPES = [torch.half, torch.bfloat16]
else:
    DTYPES = [torch.half, torch.bfloat16, torch.float]
CUDA_DEVICES = ["cuda"]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """Compare MultiHeadAttention's forward output with ``ref_attention``.

    Q/K/V are created packed as ``[batch_size, seq_len, heads * head_size]``
    and fed to the layer. For the reference they are reshaped to
    ``[batch_size, seq_len, heads, head_size]``, and the K/V heads are repeated
    so that each query head has a matching K/V head.
    """
    # Validate the parametrization before doing any work.
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    current_platform.seed_everything(0)
    # Scope the default device and dtype to this test so that neither leaks
    # into later tests in the same session.
    prev_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        with torch.device(device):
            q = torch.randn(batch_size, seq_len, num_heads * head_size)
            k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
            scale = 1.0 / head_size**0.5
            attn = MultiHeadAttention(num_heads,
                                      head_size,
                                      scale=scale,
                                      num_kv_heads=num_kv_heads)
            output = attn(q, k, v)

            q = q.reshape(batch_size, seq_len, num_heads, head_size)
            k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
            v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
            if num_queries_per_kv > 1:
                # Expand K/V heads to one per query head.
                k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
                v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

            ref_output = ref_attention(
                q,
                k,
                v,
                scale=scale,
            ).reshape(batch_size, seq_len, num_heads * head_size)
            torch.testing.assert_close(output, ref_output)
    finally:
        torch.set_default_dtype(prev_dtype)
|