vllm/tests/neuron/test_block_table.py

# SPDX-License-Identifier: Apache-2.0
import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)


def nki_load_and_transform_block_tables(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
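    """Device-side wrapper around the two NKI helpers under test.

    Loads ``block_tables`` from HBM into SBUF, rewrites the block ids for
    indirect per-head loading, and stores the transformed table back to HBM
    so it can be compared against the PyTorch reference.
    """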
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
block_tables_sbuf = load_block_tables(block_tables, num_tiles,
num_blocks_per_tile)
    # the helper expects head_id as an nl index expression rather than a
    # plain Python int
head_id = nl.arange(1)[None, :] + head_id
block_tables_transposed = transform_block_tables_for_indirect_load(
block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
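    # the transformed table is expected to have the 128-entry partition size
    # as its second dimension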
B_P_SIZE = 128
assert block_tables_transposed.shape[1] == B_P_SIZE
out = nl.ndarray(
block_tables_transposed.shape,
dtype=nl.int32,
buffer=nl.shared_hbm,
)
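    # store each 128-row slab of the transformed table back to HBM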
for i in nl.affine_range(block_tables_transposed.shape[0]):
nl.store(dst=out[i], value=block_tables_transposed[i])
return out


def ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
num_head,
head_id,
block_size_tiling_factor,
):
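    """PyTorch reference for load_block_tables followed by
    transform_block_tables_for_indirect_load.

    Pads the tile count up to a multiple of 128, maps each block id to its
    per-head id, expands every block into ``block_size_tiling_factor``
    sub-blocks, and returns the transposed table split into chunks of 128
    rows.
    """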
assert block_tables.numel() == num_tiles * num_blocks_per_tile
block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
B_F_SIZE = 128
num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
block_tables = F.pad(
block_tables,
(0, 0, 0, num_tiles_padded - num_tiles),
"constant",
0,
)
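    # map each block id to the per-head block id used for the indirect load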
block_tables = block_tables * num_head + head_id
block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
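    # expand each block into block_size_tiling_factor sub-blocks so a tile
    # spans a full 128-entry partition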
offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
block_tables = block_tables * block_size_tiling_factor + offset
block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()
num_blocks_per_tile = block_tables_transposed.shape[0]
assert num_blocks_per_tile % B_F_SIZE == 0
return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
B_F_SIZE, num_tiles_padded)


@pytest.mark.parametrize(
"q_head_per_kv_head,head_id",
[
(1, 0),
(3, 1),
],
)
@pytest.mark.parametrize(
"num_tiles,num_blocks_per_tile",
[
(1, 1),
(13, 16),
(17, 128),
(35, 512),
(128, 128),
(130, 64),
(280, 256),
(315, 1),
],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
) -> None:
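    """Compare the NKI block-table load/transform against the PyTorch
    reference implementation."""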
import torch_xla.core.xla_model as xm
device = xm.xla_device()
compiler_flags = [
"-O1",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
torch.manual_seed(10000)
torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
B_P_SIZE = 128
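    # when a tile has fewer than 128 blocks, split each block into sub-blocks
    # so the transformed table still fills a 128-entry partition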
if num_blocks_per_tile < B_P_SIZE:
assert B_P_SIZE % num_blocks_per_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
else:
block_size_tiling_factor = 1
max_num_blocks = 100000
block_tables = torch.randint(
0,
max_num_blocks,
(num_tiles * num_blocks_per_tile, ),
dtype=torch.int32,
)
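
    # launch the NKI kernel on a 1x1 grid and copy the result back to host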
nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
block_tables.to(device=device),
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
).cpu()
ref_out = ref_block_tables_transform(
block_tables,
num_tiles,
num_blocks_per_tile,
q_head_per_kv_head,
head_id,
block_size_tiling_factor,
)

    assert (nki_out.shape == ref_out.shape
), f"{nki_out.shape=} != {ref_out.shape=}"
assert torch.all(nki_out == ref_out)