[Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache (#11277)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: Jiangfei Duan <jfduan@outlook.com>
Author: Liangfu Chen
Date: 2025-01-27 17:31:16 -08:00 (committed by GitHub)
Parent: 823ab79633
Commit: ddee88d0ff
3 changed files with 1126 additions and 1 deletions

.buildkite/run-neuron-test.sh
@@ -54,4 +54,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
     -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
     --name "${container_name}" \
     ${image_name} \
-    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py"
+    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"

tests/neuron/test_prefix_prefill.py
@@ -0,0 +1,456 @@
import random
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
class BlockDiagonalCausalFromBottomRightMask:
@staticmethod
def _from_seqlens(query_lens, seq_lens, block_size=None):
from torch import logical_and, logical_or
contexted = block_size is None
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
n_queries = sum(query_lens)
num_seqs = len(query_lens)
if contexted:
key_lens_blockaligned = seq_lens
else:
n_blocks_per_seq = (context_lens + block_size - 1) // block_size
offset_per_seq = n_blocks_per_seq * block_size
key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
n_keys = sum(key_lens_blockaligned)
a = (torch.arange(n_queries).reshape(n_queries,
1).expand(n_queries, n_keys))
b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)
prior_mask = torch.zeros(n_queries, n_keys)
for seq_id in range(num_seqs):
ri = q_cumsum[seq_id]
ci = k_cumsum[seq_id]
nr = query_lens[seq_id]
if contexted:
nc = seq_lens[seq_id]
a_offset = ci + nc - ri - nr
new_mask = (a + a_offset) >= b
else:
nc = context_lens[seq_id]
a_offset = ci + nc - 1
new_mask = a_offset >= b
left_mask = b >= ci
top_mask = a >= ri
bottom_mask = a < (ri + nr)
new_mask = logical_and(
logical_and(logical_and(new_mask, left_mask), top_mask),
bottom_mask,
)
prior_mask = logical_or(prior_mask, new_mask)
return prior_mask
@staticmethod
def from_seqlens(query_lens, seq_lens, block_size=None):
contexted = block_size is None
if contexted:
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, seq_lens)
active_mask = None
else:
prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, seq_lens, block_size)
active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
query_lens, query_lens)
return prior_mask, active_mask
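# A minimal usage sketch (illustrative only): with a single sequence of 2 new
# queries against 4 total keys, the prior mask is causal from the bottom-right
# corner, so the first query already sees the 3 oldest keys:
#   >>> m, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens([2], [4])
#   >>> m.long()
#   tensor([[1, 1, 1, 0],
#           [1, 1, 1, 1]])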
def ref_softmax(x: torch.Tensor,
dim: int,
mixed_precision=False,
return_max_reduce=False):
max_value = torch.amax(x, dim=dim, keepdims=True)
exp = torch.exp(x - max_value)
if mixed_precision:
        sum_value = torch.sum(exp.to(torch.float32),
                              dim=dim,
                              keepdims=True).to(x.dtype)
else:
sum_value = torch.sum(exp, dim=dim, keepdims=True)
if return_max_reduce:
return exp / sum_value, max_value, torch.reciprocal(sum_value)
return exp / sum_value
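# Sanity check (illustrative, not exercised by the test): with
# mixed_precision=False this reduces to a numerically-stabilized softmax and
# matches torch.softmax:
#   x = torch.randn(4, 8)
#   assert torch.allclose(ref_softmax(x, dim=-1), torch.softmax(x, dim=-1))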
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
return_max_reduce: Optional[bool] = False,
) -> torch.Tensor:
scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
masked_score = scaled_qk + attn_mask.float()
if return_max_reduce:
norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
masked_score, dim=-1, return_max_reduce=True)
else:
norm_score = ref_softmax(masked_score, dim=-1)
out = torch.einsum("hqk,khd->qhd", norm_score, value)
if return_max_reduce:
return (
out,
cached_max,
cached_sum_reciprocal,
norm_score,
masked_score,
scaled_qk,
)
else:
return out
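# Shape walkthrough for the einsums above: with query (q, h, d) and key
# (k, h, d), "qhd,khd->hqk" produces per-head scores of shape (h, q, k);
# after the softmax, "hqk,khd->qhd" brings the output back to the
# token-major (q, h, d) layout used by the caller.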
def ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=False,
):
scale = float(1.0 / (head_size**0.5))
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens)
# convert binary mask to -inf values
attn_mask = torch.logical_not(attn_mask)
attn_mask = attn_mask.float() * -30000
    result = ref_masked_attention(
        query,
        key,
        value,
        scale,
        attn_mask,
        return_max_reduce=return_max_reduce,
    )
    if return_max_reduce:
        (output, cached_max, cached_sum_reciprocal, lse, masked_score,
         scaled_qk) = result
    else:
        output = result
output = output.unsqueeze(1)
if return_max_reduce:
return (
output,
cached_max,
cached_sum_reciprocal,
lse,
masked_score,
scaled_qk,
)
else:
return output
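# Note on the GQA expansion above: repeat_interleave along the head dim maps
# kv head i to query heads [i * num_queries_per_kv, (i + 1) * num_queries_per_kv),
# e.g. with 32 query heads and num_queries_per_kv=8, kv head 0 serves query
# heads 0-7.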
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
],
)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
mixed_precision: bool,
) -> None:
import os
import torch_xla.core.xla_model as xm
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
device = xm.xla_device()
os.environ["NEURON_CC_FLAGS"] = (
" --model-type=transformer -O1 "
" --internal-hlo2tensorizer-options='--verify-hlo' ")
random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)
min_ctx_len = 2
max_ctx_len = 64
min_query_len = 2
max_query_len = 64
prefill_batch_size = 2
decode_batch_size = 6
batch_size = prefill_batch_size + decode_batch_size
block_size = 32
max_model_len = (max_query_len + max_ctx_len) * 4
max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
ctx_lens = [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(prefill_batch_size)
] + [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(decode_batch_size)
]
query_lens = [
random.randint(min_query_len, max_query_len)
for _ in range(prefill_batch_size)
] + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1, 1)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1, 1)
key, value = kv.unbind(dim=1)
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:batch_size * max_block_per_request].view(
batch_size, max_block_per_request)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
for i in range(batch_size):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
(
output_ref,
cached_max,
cached_sum_reciprocal,
lse,
masked_score,
scaled_qk,
) = ref_context_attention(
query,
key,
value,
query_lens,
seq_lens,
head_size,
num_kv_heads,
num_heads,
num_queries_per_kv,
return_max_reduce=True,
)
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = 2048
max_num_queries = (
(sum(query_lens) + block_size - 1) // block_size) * block_size
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
context_lens = seq_lens - query_lens
blocks_per_seq = (context_lens + block_size - 1) // block_size
num_seqs = len(seq_lens)
active_blocks: list[int] = []
for seq_id in range(num_seqs):
active_blocks = (
active_blocks +
block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
return F.pad(
torch.tensor(active_blocks),
(0, num_blocks - len(active_blocks)),
"constant",
0,
)
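    # Worked example (hypothetical values): block_size=32,
    # block_tables=[[5, 6, 7], [8, 9, 10]], query_lens=[1, 1],
    # seq_lens=[33, 17] gives context_lens=[32, 16] and blocks_per_seq=[1, 1],
    # so the active tables are [5, 8], zero-padded on the right to num_blocks.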
def shift_bit_length(x):
return 1 << (x - 1).bit_length()
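    # e.g. shift_bit_length(5) == 8 and shift_bit_length(8) == 8: rounds up
    # to the next power of two (identity on exact powers of two).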
# calculate input shapes
max_num_queries_shifted = shift_bit_length(max_num_queries)
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
    assert (max_num_queries_padded == B_P_SIZE
            ), f"invalid {max_num_queries_padded=}"
head_size_padded = B_P_SIZE
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks_shifted = shift_bit_length(
((context_lens + block_size - 1) // block_size).sum().item())
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
num_active_blocks_shifted)
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
    assert (num_active_blocks *
            block_size) == LARGE_TILE_SZ, f"invalid {num_active_blocks=}"
context_kv_len = num_active_blocks * block_size
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
# pad QKV tensors
pad_dims = (
0,
head_size_padded - query.shape[2],
0,
0,
0,
max_num_queries_padded - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
v = F.pad(v, pad_dims, "constant", 0)
k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)
# permute QKV tensors
# query: (1, n_heads, d, seq_q)
# key: (1, n_kv_heads, d, seq_k)
# value: (1, n_kv_heads, seq_v, d)
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
# transform block table
active_block_table = get_active_block_tables(
block_table,
torch.tensor(query_lens),
torch.tensor(seq_lens),
block_size,
num_active_blocks,
)
# Build attention masks
prior_mask, active_mask = (
BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens, block_size=block_size))
attn_mask = torch.concat(
[
F.pad(
prior_mask,
(
0,
context_kv_len - prior_mask.shape[1],
0,
B_P_SIZE - prior_mask.shape[0],
),
"constant",
0,
).bool(),
F.pad(
active_mask,
(
0,
B_P_SIZE - active_mask.shape[1],
0,
B_P_SIZE - active_mask.shape[0],
),
"constant",
0,
).bool(),
],
dim=1,
)
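    # The resulting mask has shape (B_P_SIZE, context_kv_len + B_P_SIZE),
    # here (128, 2048 + 128): prior-context columns first, active-token
    # columns last, matching the kernel's [paged cache | in-flight KV]
    # column order.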
input_args = (
query.to(device=device),
k.to(device=device),
v.to(device=device),
k_cache.to(device=device),
v_cache.to(device=device),
active_block_table.to(torch.int32).to(device=device),
attn_mask.to(device=device),
)
input_kwargs = dict(
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
)
if return_debug_tensors:
output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
*input_args, **input_kwargs)
else:
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []
output_nki = torch.tensor(output_nki).cpu()
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
num_actual_tokens = sum(query_lens)
print(f"{num_actual_tokens=}")
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.permute(
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
"constant",
0,
)
output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :]
torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)

vllm/attention/ops/nki_flash_attn.py
@@ -0,0 +1,669 @@
from dataclasses import dataclass
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
@dataclass(frozen=True)
class FlashConfig:
"""
Config class for flash attention with default values
"""
seq_tile_size: int = 2048
should_transpose_v: bool = False
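# Example: FlashConfig(seq_tile_size=1024) shrinks the KV tile processed per
# inner-loop iteration; the dataclass is frozen, so a config is immutable
# once built.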
@nki.jit
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
forward_mask,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.sbuf,
dtype=p_local.dtype)
else:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.psum,
dtype=np.float32)
for j in nl.affine_range(B_F_SIZE // 128):
j_128_slice = nl.ds(j * 128, 128)
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice], mask=forward_mask)
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
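# transpose_p_local converts the (128, LARGE_TILE_SZ) probability tile into
# K-major layout one 128x128 block at a time so the PV matmul below can
# consume it with transpose_x=True; on gen3 hardware the transpose goes
# through SBUF via dma_transpose, older generations use nc_transpose via PSUM.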
@nki.jit
def _flash_attention_core(
q_local_tile,
k,
v,
q_h_per_k_h,
seqlen_q,
nheads,
o_buffer,
l_buffer,
m_buffer,
batch_id,
head_id,
gqa_head_idx,
q_tile_idx,
local_k_large_tile_idx,
kernel_dtype,
acc_type,
flash_config: FlashConfig,
use_causal_mask=False,
continuous_batching_mask=None,
initialize=False,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=None,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
already. The block size of K and V
is defined in the seq_tile_size of the flash_config. The results are stored
in the following three buffers
o_buffer: (B_P_SIZE, d)
l_buffer: (B_P_SIZE, 1)
m_buffer: (B_P_SIZE, 1)
"""
LARGE_TILE_SZ = flash_config.seq_tile_size
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
seqlen_k = k.shape[-1]
# TODO : support logit_bias with continuous_batching_mask
assert not use_causal_mask, "causal mask is not supported."
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
assert (logit_bias_tile is
None), "continuous_batching_mask does not support logit_bias!"
    # masks are used to apply computation only to the lower half of the
    # matrix, which reduces the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
dtype=acc_type)
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True,
mask=None) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
continuous_batching_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
np.max,
qk_res_buf[:, k_i_b_f_slice],
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
if qk_res_buffer is not None:
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
max_ = nisa.tensor_reduce(
np.max,
max_local[:, :],
axis=(1, ),
dtype=acc_type,
negate=False,
mask=forward_mask,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
dtype=o_buffer.dtype)
if initialize:
m_buffer[:, 0] = nl.copy(max_)
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_,
mask=forward_mask) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
alpha = nisa.activation(
np.exp,
m_previous,
bias=-1 * m_current,
scale=1.0,
mask=forward_mask,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
alpha,
mask=forward_mask)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
p_partial_sum = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
# compute exp(qk - max)
# Compute partial row - tile sum of exp(qk - max))
# FIXME : Use activation accumulate to accumulate over k_r_i loop ?
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
np.exp,
qk_res_buf[:, k_r_i_reduce_slice],
bias=-1 * m_current,
scale=1.0,
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
mask=forward_mask,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
transpose_p_local(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
forward_mask=forward_mask,
B_F_SIZE=B_F_SIZE,
)
pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum)
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
pv_psum[:, :] += nl.matmul(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[k_i, :, :],
transpose_x=True,
mask=forward_mask,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(
nl.subtract(l_prev, m_current, mask=forward_mask),
mask=forward_mask,
),
ps,
mask=forward_mask,
)
l_buffer[:, 0] = nl.add(m_current,
nl.log(l_exp, mask=forward_mask),
mask=forward_mask)
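# For reference, the online-softmax recurrence maintained across K tiles by
# _flash_attention_core is:
#   m_new = max(m_prev, max(S_tile))
#   o_new = o_prev * exp(m_prev - m_new) + exp(S_tile - m_new) @ V_tile
#   l_new = m_new + log(exp(l_prev - m_new) + sum(exp(S_tile - m_new)))
# The unnormalized o is rescaled only once at the end of the kernel, where
# exp(m - l) equals the reciprocal of the softmax denominator.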
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
LARGE_TILE_SZ = config.seq_tile_size
B_P_SIZE = 128
if not config.should_transpose_v:
cur_v_tile[v_i, :, :] = nl.load(
v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
dtype=cur_v_tile.dtype,
)
return
if nisa.get_nc_version() == nisa.nc_version.gen3:
cur_v_tile_transposed = nisa.dma_transpose(
v_hbm_tile[:,
nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
dtype=cur_v_tile.dtype)
return
cur_v_tile[v_i, :, :] = nl.load_transpose2d(
v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
dtype=cur_v_tile.dtype,
)
@nki.jit
def flash_paged_attention(
query,
key,
value,
key_cache,
value_cache,
block_tables,
mask,
softmax_scale=None,
mixed_precision=True,
config=None,
return_debug_tensors=False,
):
"""
Flash PagedAttention Forward Kernel.
- PagedAttention Paper: https://arxiv.org/abs/2309.06180
- Chunked Prefill Paper: https://arxiv.org/abs/2403.02310
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, block_size, n_kv_heads, d)
- value_cache: (num_blocks, block_size, n_kv_heads, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size)
- o: shape (1, n_heads, seq_q, d)
- l_m: shape (1, n_heads, seq_q, 2)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
    - If mixed_precision is True, then all Tensor Engine operations will be
    performed in bfloat16 and accumulation will be performed in float32.
    Otherwise the intermediates will be in the same type as the inputs.
    Compile-time Constants:
    - softmax_scale: scaling for softmax; if None, defaults to `1.0/(d**0.5)`
    - mixed_precision: flag to keep non-matmul ops in fp32 precision; default
    is `True`; if `False`, the same precision as the input types is used
    - config: instance of dataclass :class:`nki.kernels.attention.FlashConfig`
    with performance config parameters for flash attention, with defaults:
    seq_tile_size: `default=2048`, size of the KV tile used in each
    attention reduction step
    GQA support notes:
    the SPMD grid for launching the kernel should be over kv_heads instead of
    nheads
Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: `flash_fwd[b, h](q, k, v, ...)`
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
config = config or FlashConfig()
B_F_SIZE = 512
B_P_SIZE = 128
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
LARGE_TILE_SZ = config.seq_tile_size
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
num_blocks, block_size, k_h, _ = key_cache.shape
q_h_per_k_h = h // k_h
assert tuple(key_cache.shape) == (
num_blocks,
block_size,
k_h,
d,
), "Input shape mismatch!"
assert tuple(value_cache.shape) == (
num_blocks,
block_size,
k_h,
d,
), "Input shape mismatch!"
assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f"we do not support head_dim > 128, got head_dim {d}"
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
o = nl.ndarray((b, h, seqlen_q, d),
dtype=query.dtype,
buffer=nl.shared_hbm)
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
None,
None,
None,
None,
)
if return_debug_tensors:
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
qk_res_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
softmax_scale = softmax_scale or (1.0 / (d**0.5))
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
    assert (config.seq_tile_size >= 512
            ), f"seq_tile_size {config.seq_tile_size} cannot be less than 512"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
assert (
LARGE_TILE_SZ % B_P_SIZE == 0
), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}"
assert (B_P_SIZE % block_size == 0
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert (num_blocks_per_large_tile <= B_P_SIZE
), f"The number of blocks in each large tile " \
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
0,
dtype=np.int32,
buffer=nl.sbuf)
for j in nl.affine_range(num_large_k_tile):
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
block_tables_sbuf[i_p, j] = nl.load(
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
l_buffer = nl.zeros(
(par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
m_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
for j in nl.sequential_range(0, num_large_k_tile):
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
dtype=kernel_dtype,
)
for k_i in nl.affine_range(num_blocks_per_large_tile):
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
head_id, :])
cur_k_tile[:, nl.ds(k_i *
block_size, block_size)] = nl.transpose(loaded)
load_tile_size = B_P_SIZE
num_blocks_per_partition = load_tile_size // block_size
for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
for block_in_partition in nl.affine_range(
num_blocks_per_partition):
v_i = (partition_idx * num_blocks_per_partition +
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
head_id, :])
cur_v_tile[partition_idx,
nl.ds(block_in_partition *
block_size, block_size), :, ] = loaded_v
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
dtype=kernel_dtype,
) # load (d, 128) tile in SBUF
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=j,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
initialize=j == 0,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = seqlen_q
LARGE_TILE_SZ = seqlen_q
active_config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
should_transpose_v=config.should_transpose_v,
)
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
dtype=kernel_dtype,
)
cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :])
load_tile_size = B_P_SIZE
v_hbm_tile = value[batch_id, head_id]
for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
load_v_tile(
v_hbm_tile=v_hbm_tile,
cur_v_tile=cur_v_tile,
j=0,
v_i=v_i,
config=active_config,
)
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(
q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
dtype=kernel_dtype,
) # load (d, 128) tile in SBUF
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
q_h_per_k_h=q_h_per_k_h,
seqlen_q=seqlen_q,
nheads=h,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[:, i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
batch_id=batch_id,
head_id=head_id,
gqa_head_idx=i_q_h,
q_tile_idx=i,
local_k_large_tile_idx=0,
kernel_dtype=kernel_dtype,
acc_type=acc_type,
flash_config=active_config,
use_causal_mask=False,
continuous_batching_mask=cur_mask,
initialize=False,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
dropout_p=0.0,
dropout_p_tensor=None,
seed_tensor=None,
logit_bias_tile=None,
qk_res_buffer=qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None,
)
    # ---------------- write output to buffer on HBM ---------------- #
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
out = nl.multiply(
o_buffer[i, i_q_h, :, :],
nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]),
dtype=kernel_dtype,
)
nl.store(
o[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), ],
m_buffer[i, i_q_h, :, :],
)
nl.store(
hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE), ],
l_buffer[:, i, i_q_h],
)
nl.store(
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
qk_res_buffer[batch_id, i_q_h, :, :],
)
if return_debug_tensors:
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
return o
def flash_attn_varlen_nkifunc(
query,
key,
value,
key_cache,
value_cache,
block_table,
attn_mask,
n_kv_head=None,
head_size=None,
B_P_SIZE=128,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
mixed_precision=True,
):
config = FlashConfig(
seq_tile_size=LARGE_TILE_SZ,
should_transpose_v=False,
)
kwargs = dict(
query=query,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
config=config,
mixed_precision=mixed_precision,
return_debug_tensors=return_debug_tensors,
)
    if n_kv_head is None:
        _, n_kv_head, _, _ = key.shape
if return_debug_tensors:
o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs)
return o, *debug_tensors
else:
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o
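# Minimal call sketch (shapes mirror the unit test above; names are
# illustrative):
#   query:       (1, n_heads, d, padded_num_tokens)
#   key:         (1, n_kv_heads, d, padded_num_tokens)
#   value:       (1, n_kv_heads, padded_num_tokens, d)
#   key_cache:   (num_blocks, block_size, n_kv_heads, d)
#   value_cache: (num_blocks, block_size, n_kv_heads, d)
#   block_table: (num_active_blocks,), int32
#   attn_mask:   (padded_num_tokens,
#                 num_active_blocks * block_size + padded_num_tokens), bool
#   out = flash_attn_varlen_nkifunc(query, key, value, key_cache, value_cache,
#                                   block_table, attn_mask,
#                                   n_kv_head=n_kv_heads, head_size=d)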