# SPDX-License-Identifier: Apache-2.0
"""Test that the NKI block-table load/transform kernel matches a pure-PyTorch
reference implementation."""

import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)


def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id
    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    # Copy the transposed block table from SBUF back to HBM, one tile at a
    # time, so the host-side test can compare it against the reference.
    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out


def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    block_tables = block_tables.view(num_tiles, num_blocks_per_tile)

    # Pad the number of tiles up to a multiple of B_F_SIZE.
    B_F_SIZE = 128
    num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
    block_tables = F.pad(
        block_tables,
        (0, 0, 0, num_tiles_padded - num_tiles),
        "constant",
        0,
    )

    # Map each block id to its per-head entry: block_id * num_head + head_id.
    block_tables = block_tables * num_head + head_id

    # Split each block into block_size_tiling_factor sub-blocks and enumerate
    # them, then transpose so blocks become the leading dimension.
    block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
    offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    block_tables = block_tables * block_size_tiling_factor + offset
    block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()
    num_blocks_per_tile = block_tables_transposed.shape[0]
    assert num_blocks_per_tile % B_F_SIZE == 0
    return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
                                        B_F_SIZE, num_tiles_padded)


@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags = [
        "-O1",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    B_P_SIZE = 128
    if num_blocks_per_tile < B_P_SIZE:
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
    else:
        block_size_tiling_factor = 1

    max_num_blocks = 100000
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (num_tiles * num_blocks_per_tile, ),
        dtype=torch.int32,
    )

    # Run the NKI kernel on device and bring the result back to CPU.
    nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
        block_tables.to(device=device),
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    ).cpu()
    ref_out = ref_block_tables_transform(
        block_tables,
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )
    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)
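

# A minimal sketch, not part of the test suite: running this module directly
# (outside pytest) exercises only the pure-PyTorch reference transform on CPU,
# which illustrates the padding/tiling math without needing Neuron hardware.
# The parameter values below are arbitrary assumptions chosen for illustration.
if __name__ == "__main__":
    demo_num_tiles, demo_num_blocks_per_tile = 13, 16
    demo_tiling_factor = 128 // demo_num_blocks_per_tile
    demo_block_tables = torch.randint(
        0,
        1000,
        (demo_num_tiles * demo_num_blocks_per_tile, ),
        dtype=torch.int32,
    )
    demo_out = ref_block_tables_transform(
        demo_block_tables,
        demo_num_tiles,
        demo_num_blocks_per_tile,
        num_head=1,
        head_id=0,
        block_size_tiling_factor=demo_tiling_factor,
    )
    # num_tiles is padded to a multiple of 128 and each block is split into
    # demo_tiling_factor sub-blocks, so the expected shape here is
    # (16 * 8 // 128, 128, 128) == (1, 128, 128).
    print(demo_out.shape)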