# SPDX-License-Identifier: Apache-2.0
from unittest.mock import ANY, patch

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                               PallasMetadata)


def test_ragged_paged_attention():
    # We verify that kernel inputs such as sliding_window are passed in from
    # the model correctly.
    # The correctness of the paged attention kernel itself is tested in the
    # kernel library.
    num_heads = 4
    head_size = 128
    scale = 1.0
    num_kv_heads = 4
    sliding_window = 128
    logits_soft_cap = 50.0
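    # Construct the backend with a sliding window and a logits soft cap so the
    # test can check that both are forwarded to the kernel call.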
    attn_impl = PallasAttentionBackendImpl(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=sliding_window,
        kv_cache_dtype="auto",
        logits_soft_cap=logits_soft_cap,
        attn_type=AttentionType.DECODER,
    )
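
    # Minimal stand-in for an attention layer; forward() only needs the
    # k/v scale attributes.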
    class FakeAttentionLayer:
        _k_scale_float: float
        _v_scale_float: float

    layer = FakeAttentionLayer()
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
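
    # Dummy query/key/value tensors; only the shapes matter since the kernel
    # call is mocked out below.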
    num_tokens = 16
    num_blocks = 1024
    block_size = 16
    query = torch.zeros(num_tokens, num_heads * head_size)
    key = torch.zeros(num_tokens, num_kv_heads * head_size)
    value = torch.zeros(num_tokens, num_kv_heads * head_size)
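    # Single KV cache tensor; the num_kv_heads * 2 dimension holds the K and V
    # heads together.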
    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
    max_num_reqs = 8
    max_num_blocks_per_req = 8
    block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
                               dtype=torch.int32)
    context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
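    # Eight requests, each contributing a single query token (decode-style).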
    query_lens = [1] * max_num_reqs
    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                dtype=torch.int32),
                                   dim=0,
                                   dtype=torch.int32)
    num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
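    # Per-batch metadata consumed by the Pallas backend.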
    attn_metadata = PallasMetadata(
        slot_mapping=slot_mapping,
        block_tables=block_tables,
        context_lens=context_lens,
        query_start_loc=query_start_loc,
        num_seqs=num_seqs,
    )
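
    # Mock the XLA ragged paged attention op so we can capture the exact
    # arguments that forward() passes to the kernel.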
    with patch("torch.ops.xla.ragged_paged_attention"
               ) as mock_ragged_paged_attention:
        attn_impl.forward(
            layer=layer,
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
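
        # Tensor arguments are matched with ANY; the keyword arguments must
        # equal the values configured on the backend above.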
        mock_ragged_paged_attention.assert_called_once_with(
            ANY,  # query
            ANY,  # kv_cache
            ANY,  # context_lens
            ANY,  # block_tables
            ANY,  # query_start_loc
            ANY,  # num_seqs
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=scale,
            sliding_window=sliding_window,
            soft_cap=logits_soft_cap,
        )