# SPDX-License-Identifier: Apache-2.0
from unittest.mock import ANY, patch

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                               PallasMetadata)


def test_ragged_paged_attention():
    # We verify that the kernel inputs such as sliding_window, etc. are passed
    # in from the model correctly.
    # The correctness of the paged attention kernel is tested in the kernel
    # library.
    num_heads = 4
    head_size = 128
    scale = 1.0
    num_kv_heads = 4
    sliding_window = 128
    logits_soft_cap = 50.0
    attn_impl = PallasAttentionBackendImpl(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=sliding_window,
        kv_cache_dtype="auto",
        logits_soft_cap=logits_soft_cap,
        attn_type=AttentionType.DECODER,
    )

    class FakeAttentionLayer:
        _k_scale_float: float
        _v_scale_float: float

    layer = FakeAttentionLayer()
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

    num_tokens = 16
    num_blocks = 1024
    block_size = 16
    query = torch.zeros(num_tokens, num_heads * head_size)
    key = torch.zeros(num_tokens, num_kv_heads * head_size)
    value = torch.zeros(num_tokens, num_kv_heads * head_size)
    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
    max_num_reqs = 8
    max_num_blocks_per_req = 8
    block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
                               dtype=torch.int32)
    context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
    query_lens = [1] * max_num_reqs
    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                dtype=torch.int32),
                                   dim=0,
                                   dtype=torch.int32)
    num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
    attn_metadata = PallasMetadata(
        slot_mapping=slot_mapping,
        block_tables=block_tables,
        context_lens=context_lens,
        query_start_loc=query_start_loc,
        num_seqs=num_seqs,
    )

    with patch("torch.ops.xla.ragged_paged_attention"
               ) as mock_ragged_paged_attention:
        attn_impl.forward(
            layer=layer,
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        mock_ragged_paged_attention.assert_called_once_with(
            ANY,  # query
            ANY,  # kv_cache
            ANY,  # context_lens
            ANY,  # block_tables
            ANY,  # query_start_loc
            ANY,  # num_seqs
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=scale,
            sliding_window=sliding_window,
            soft_cap=logits_soft_cap,
        )
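

# Illustrative sketch (not part of the original test): the query_start_loc
# metadata built above is the inclusive prefix sum over [0] + query_lens,
# i.e. the exclusive start offset of each request's tokens in the ragged
# query batch. With eight requests of one token each, the expected offsets
# are 0..8. This standalone check only documents that construction; the
# test name is an assumption added here for illustration.
def test_query_start_loc_prefix_sum_sketch():
    max_num_reqs = 8
    query_lens = [1] * max_num_reqs
    # Same computation as in test_ragged_paged_attention above.
    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                dtype=torch.int32),
                                   dim=0,
                                   dtype=torch.int32)
    assert query_start_loc.tolist() == list(range(max_num_reqs + 1))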