[TPU][V1] MHA Pallas backend (#15288)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent baec0d4de9
commit cfbb8c930f

tests/v1/tpu/test_mha_attn.py | 109 lines (new file)
tests/v1/tpu/test_mha_attn.py (new file)
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test:

* Tests for MultiHeadAttention layer
"""

import pytest
import torch
import torch_xla
import torch_xla.core
import torch_xla.core.xla_model

from vllm import envs
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching."""
    _cached_get_attn_backend.cache_clear()


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - attn_mask: [batch_size, seq_len, seq_len]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This test needs a TPU")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()])
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    device: str,
):
    current_platform.seed_everything(0)
    # These are expected to be f32
    q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device)
    k = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)
    v = torch.randn(batch_size,
                    seq_len,
                    num_kv_heads * head_size,
                    device=device)
    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    # torch_xla flash_attn kernel is less accurate but much faster
    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3)
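For context, a sanity sketch that is not part of this diff: it checks on CPU that the reference math used by ref_attention above agrees with torch.nn.functional.scaled_dot_product_attention (the explicit scale= keyword assumes PyTorch >= 2.1); the shapes here are arbitrary illustration values.

import torch
import torch.nn.functional as F

batch, seq, heads, head_size = 2, 4, 8, 64
scale = head_size**-0.5
q, k, v = (torch.randn(batch, seq, heads, head_size) for _ in range(3))

# Same math as ref_attention: heads in front, scaled dot product, softmax, weighted sum.
qt, kt, vt = (x.transpose(1, 2) for x in (q, k, v))
weights = torch.softmax(scale * qt @ kt.transpose(2, 3), dim=-1)
manual = (weights @ vt).transpose(1, 2)

sdpa = F.scaled_dot_product_attention(qt, kt, vt, scale=scale).transpose(1, 2)
torch.testing.assert_close(manual, sdpa, atol=1e-4, rtol=1e-4)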
vllm/attention/layer.py
@@ -281,8 +281,7 @@ class MultiHeadAttention(nn.Module):
             backend = _Backend.XFORMERS

         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA

     def forward(
@@ -320,6 +319,13 @@ class MultiHeadAttention(nn.Module):
                 value,
                 scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+            from torch_xla.experimental.custom_kernel import flash_attention
+            out = flash_attention(query, key, value, sm_scale=self.scale)
+            out = out.transpose(1, 2)

         return out.reshape(bsz, q_len, -1)
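A minimal usage sketch of the new code path, not taken from the commit: it assumes a TPU host with torch_xla installed and VLLM_USE_V1=1, and mirrors the shapes used in the test above; under those assumptions MultiHeadAttention.forward routes through the PALLAS_VLLM_V1 branch and torch_xla's flash_attention kernel.

# Hypothetical example (assumes a TPU device and VLLM_USE_V1=1); mirrors the new test.
import torch
import torch_xla.core.xla_model as xm

from vllm.attention.layer import MultiHeadAttention

device = xm.xla_device()
num_heads, num_kv_heads, head_size = 16, 1, 64
scale = head_size**-0.5

attn = MultiHeadAttention(num_heads,
                          head_size,
                          scale=scale,
                          num_kv_heads=num_kv_heads)
# Inputs are flattened to [batch_size, seq_len, num_heads * head_size].
q = torch.randn(1, 1, num_heads * head_size, device=device)
k = torch.randn(1, 1, num_kv_heads * head_size, device=device)
v = torch.randn(1, 1, num_kv_heads * head_size, device=device)
out = attn(q, k, v)  # dispatched to the Pallas flash_attention kernel on TPU
print(out.shape)  # torch.Size([1, 1, 1024])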