# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform


def cdiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` for positive integers using integer arithmetic."""
    return (a + b - 1) // b


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
) -> Tensor:
    """Reference MLA decode used to validate the CPU kernel.

    For each sequence ``i``:

    * gather the KV-cache blocks listed in ``block_tables[i]``;
    * flatten them into a single token axis;
    * keep the first ``seq_lens[i]`` tokens;
    * run scaled dot-product attention.

    The value tensor is the first ``v_head_dim`` channels of the same cache
    rows used as keys. All ``num_heads`` query heads attend to that single KV
    head (``enable_gqa=True``).

    Results are written into ``out`` in place, and ``out`` is also returned.
    """
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # Gather this sequence's blocks and flatten them into one token axis.
        kv = kv_cache[block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        # Keep only the valid tokens; anything past seq_lens[i] is ignored.
        kv = kv.view(1, -1, head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
        # Values are a channel prefix of the same cache rows used as keys.
        v = kv[:, :, :v_head_dim]

        # num_heads query heads vs. one KV head -> grouped-query attention.
        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("bs", [4])
@pytest.mark.parametrize("mean_seq_len", [256])
@pytest.mark.parametrize("h_q", [16])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_mla_decode_cpu(
    bs: int,
    mean_seq_len: int,
    h_q: int,
    d: int,
    dv: int,
    block_size: int,
    dtype: torch.dtype,
    varlen: bool,
):
    """Compare ``ops.mla_decode_kvcache_cpu`` against ``ref_mla``.

    The KV-cache slots past each sequence's length are filled with NaN. Any
    out-of-bounds read by the kernel therefore shows up as NaN in its output.
    """
    # torch.set_default_dtype changes process-wide state. Restore the previous
    # value so later tests in the same session do not silently inherit
    # half/bfloat16 as their default dtype.
    prev_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        torch.manual_seed(0)

        scale = d**(-0.5)
        if varlen:
            # Random lengths around the mean, clipped so each is >= 2 tokens.
            seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
            seq_lens = seq_lens.clip(2).to(torch.int32)
        else:
            seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
        max_seq_len = seq_lens.max().item()
        # NOTE(review): each sequence's slot is padded to a multiple of 256
        # tokens. The original author was unsure whether the kernel requires
        # this; confirm before changing.
        seqlen_pad = cdiv(max_seq_len, 256) * 256

        q = torch.randn(bs, h_q, d)
        # Sequence i owns the contiguous block ids
        # [i * n, (i + 1) * n) with n = seqlen_pad // block_size.
        block_table = torch.arange(bs * seqlen_pad // block_size,
                                   dtype=torch.int32)
        block_table = block_table.view(bs, seqlen_pad // block_size)

        kv_cache = torch.randn(block_table.numel(), block_size, d)
        # Poison the padded tail of every sequence so stray reads yield NaN.
        for i, seq_len in enumerate(seq_lens.tolist()):
            kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")

        out_mla = q.new_zeros(bs, h_q, dv)
        ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
                                   seq_lens)

        out_ref = q.new_zeros(bs, h_q, dv)
        ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

        assert not out_mla.isnan().any(), "Likely read out of bounds"
        torch.testing.assert_close(out_mla, out_ref)
    finally:
        torch.set_default_dtype(prev_default_dtype)