diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py
new file mode 100644
index 00000000..01664598
--- /dev/null
+++ b/tests/v1/tpu/test_mha_attn.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Test:
+
+* Tests for MultiHeadAttention layer
+"""
+
+import pytest
+import torch
+import torch_xla
+import torch_xla.core
+import torch_xla.core.xla_model
+
+from vllm import envs
+from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.selector import _cached_get_attn_backend
+from vllm.platforms import current_platform
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
+def ref_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: float,
+) -> torch.Tensor:
+    """
+    Native implementation of scaled dot product attention without mask:
+    - query, key, value: [batch_size, seq_len, num_heads, head_size]
+    - attn_mask: [batch_size, seq_len, seq_len]
+    """
+    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
+    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.matmul(attn_weights, value).transpose(1, 2)
+    return out
+
+
+BATCH_SIZES = [1, 16]
+SEQ_LENS = [1]
+NUM_HEADS = [1, 16]
+NUM_KV_HEADS = [1]
+HEAD_SIZES = [64, 80]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
+def test_mha_attn_forward(
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    device: str,
+):
+    current_platform.seed_everything(0)
+    # These are expected to be f32
+    q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
+    k = torch.randn(batch_size,
+                    seq_len,
+                    num_kv_heads * head_size,
+                    device=device)
+    v = torch.randn(batch_size,
+                    seq_len,
+                    num_kv_heads * head_size,
+                    device=device)
+    scale = 1.0 / head_size**0.5
+    attn = MultiHeadAttention(num_heads,
+                              head_size,
+                              scale=scale,
+                              num_kv_heads=num_kv_heads)
+    output = attn(q, k, v)
+
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    q = q.reshape(batch_size, seq_len, num_heads, head_size)
+    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
+    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
+    if num_queries_per_kv > 1:
+        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
+        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
+
+    ref_output = ref_attention(
+        q,
+        k,
+        v,
+        scale=scale,
+    ).reshape(batch_size, seq_len, num_heads * head_size)
+    # torch_xla flash_attn kernel is less accurate but much faster
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 946c07d5..dbf4723e 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -281,8 +281,7 @@ class MultiHeadAttention(nn.Module):
             backend = _Backend.XFORMERS
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -320,6 +319,13 @@ class MultiHeadAttention(nn.Module):
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+            from torch_xla.experimental.custom_kernel import flash_attention
+            out = flash_attention(query, key, value, sm_scale=self.scale)
+            out = out.transpose(1, 2)
+
         return out.reshape(bsz, q_len, -1)
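Note on the new PALLAS_VLLM_V1 branch: torch_xla's Pallas flash_attention kernel operates on [batch, num_heads, seq_len, head_size] tensors, while MultiHeadAttention.forward holds activations as [batch, seq_len, num_heads, head_size], hence the transpose(1, 2) before and after the kernel call. A minimal standalone sketch of the invocation (shapes are illustrative and assume a TPU runtime with torch_xla installed):

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.experimental.custom_kernel import flash_attention

    device = xm.xla_device()
    bsz, num_heads, seq_len, head_size = 1, 16, 128, 64

    # The kernel expects [batch, num_heads, seq_len, head_size].
    q = torch.randn(bsz, num_heads, seq_len, head_size, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # sm_scale mirrors the 1/sqrt(head_size) scale the layer passes in.
    out = flash_attention(q, k, v, sm_scale=head_size**-0.5)
    assert out.shape == (bsz, num_heads, seq_len, head_size)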