95 lines
3.0 KiB
Python
95 lines
3.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
import vllm._custom_ops as ops
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
def cdiv(a, b):
    """Return ``ceil(a / b)`` for integer ``a`` and positive integer ``b``.

    Adding ``b - 1`` before the floor division bumps the quotient up by one
    whenever ``a`` is not an exact multiple of ``b``.
    """
    rounded_up = a + b - 1
    return rounded_up // b
|
|
|
|
|
|
def ref_mla(
|
|
out: Tensor, # (bs, num_heads, v_head_dim)
|
|
query: Tensor, # (bs, num_heads, head_dim)
|
|
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
|
scale: float,
|
|
block_tables: Tensor, # (bs, max_num_blocks)
|
|
seq_lens: Tensor, # (bs,)
|
|
):
|
|
bs, num_heads, v_head_dim = out.shape
|
|
head_dim = query.shape[2]
|
|
|
|
for i in range(bs):
|
|
# gather and flatten KV-cache
|
|
kv = kv_cache[
|
|
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
|
kv = kv.view(1, -1,
|
|
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
|
v = kv[:, :, :v_head_dim]
|
|
|
|
q = query[i].view(num_heads, 1, head_dim)
|
|
o = F.scaled_dot_product_attention(q,
|
|
kv,
|
|
v,
|
|
scale=scale,
|
|
enable_gqa=True)
|
|
out[i] = o.view(num_heads, v_head_dim)
|
|
|
|
return out
|
|
|
|
|
|
@pytest.mark.parametrize("bs", [4])
@pytest.mark.parametrize("mean_seq_len", [256])
@pytest.mark.parametrize("h_q", [16])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_mla_decode_cpu(
    bs: int,
    mean_seq_len: int,
    h_q: int,
    d: int,
    dv: int,
    block_size: int,
    dtype: torch.dtype,
    varlen: bool,
):
    """Check ``ops.mla_decode_kvcache_cpu`` against the dense ``ref_mla``.

    The test builds a paged KV cache in which sequence ``i`` owns a contiguous
    run of blocks. Every cache slot past a sequence's length is filled with
    NaN, so any out-of-bounds read by the kernel shows up as NaN in its output.
    The kernel result is then compared with the reference implementation.
    """
    # torch.set_default_dtype is process-global. Restore the previous value
    # afterwards so later tests in the same pytest session do not silently
    # run under half/bfloat16.
    prev_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        torch.manual_seed(0)

        scale = d**(-0.5)
        if varlen:
            seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
            seq_lens = seq_lens.clip(2).to(torch.int32)
        else:
            seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
        max_seq_len = seq_lens.max().item()

        # Pad every sequence to a common length made of whole blocks, so the
        # cache can be viewed as (bs, seqlen_pad, d) when poisoning below.
        # NOTE(review): rounding to 256 rather than to block_size is inherited;
        # unclear whether the kernel itself relies on the extra alignment.
        seqlen_pad = cdiv(max_seq_len, 256) * 256
        assert seqlen_pad % block_size == 0, (
            "block_size must divide the 256-token sequence padding")
        blocks_per_seq = seqlen_pad // block_size

        q = torch.randn(bs, h_q, d)
        # Row i of the table is [i * blocks_per_seq, (i + 1) * blocks_per_seq).
        block_table = torch.arange(bs * blocks_per_seq, dtype=torch.int32)
        block_table = block_table.view(bs, blocks_per_seq)

        kv_cache = torch.randn(block_table.numel(), block_size, d)
        # Fill each sequence's padding with NaN to expose out-of-bounds reads.
        for i, seq_len in enumerate(seq_lens.tolist()):
            kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")

        out_mla = q.new_zeros(bs, h_q, dv)
        ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
                                   seq_lens)

        out_ref = q.new_zeros(bs, h_q, dv)
        ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

        assert not out_mla.isnan().any(), "Likely read out of bounds"
        torch.testing.assert_close(out_mla, out_ref)
    finally:
        torch.set_default_dtype(prev_default_dtype)
|