[Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache (#11277)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: Jiangfei Duan <jfduan@outlook.com>
parent 823ab79633
commit ddee88d0ff
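For context, here is a minimal host-side sketch of how the new kernel entry point is driven, based on the test added in this commit. The shapes and values are illustrative only (not part of the change), and it assumes a Neuron device is reachable through torch_xla and that inputs are already padded the way the test pads them:

import torch
import torch_xla.core.xla_model as xm

from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

device = xm.xla_device()

n_heads, n_kv_heads, head_size = 4, 2, 128   # head_size already padded to 128
seq_q = 128                                  # padded to B_P_SIZE
block_size, num_active_blocks = 32, 64       # 64 * 32 == LARGE_TILE_SZ == 2048
num_blocks = 128                             # total blocks in the paged cache

query = torch.rand(1, n_heads, head_size, seq_q)        # (1, n_heads, d, seq_q)
k_active = torch.rand(1, n_kv_heads, head_size, seq_q)  # (1, n_kv_heads, d, seq_k)
v_active = torch.rand(1, n_kv_heads, seq_q, head_size)  # (1, n_kv_heads, seq_v, d)
k_cache = torch.rand(num_blocks, block_size, n_kv_heads, head_size)
v_cache = torch.rand(num_blocks, block_size, n_kv_heads, head_size)
block_table = torch.arange(num_active_blocks, dtype=torch.int32)
# The mask covers the paged context (num_active_blocks * block_size columns)
# followed by the active (new-token) part (seq_q columns).
attn_mask = torch.ones(seq_q, num_active_blocks * block_size + seq_q,
                       dtype=torch.int32)

output = flash_attn_varlen_nkifunc(
    query.to(device), k_active.to(device), v_active.to(device),
    k_cache.to(device), v_cache.to(device),
    block_table.to(device), attn_mask.to(device),
    n_kv_head=n_kv_heads, head_size=head_size, mixed_precision=True,
)  # output layout: (1, n_heads, seq_q, d)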
@@ -54,4 +54,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
     -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
     --name "${container_name}" \
     ${image_name} \
-    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py"
+    /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
tests/neuron/test_prefix_prefill.py (new file, 456 lines)
@@ -0,0 +1,456 @@
import random
from typing import Optional

import pytest
import torch
import torch.nn.functional as F


class BlockDiagonalCausalFromBottomRightMask:

    @staticmethod
    def _from_seqlens(query_lens, seq_lens, block_size=None):
        from torch import logical_and, logical_or

        contexted = block_size is None
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        n_queries = sum(query_lens)
        num_seqs = len(query_lens)
        if contexted:
            key_lens_blockaligned = seq_lens
        else:
            n_blocks_per_seq = (context_lens + block_size - 1) // block_size
            offset_per_seq = n_blocks_per_seq * block_size
            key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
        n_keys = sum(key_lens_blockaligned)

        a = (torch.arange(n_queries).reshape(n_queries,
                                             1).expand(n_queries, n_keys))
        b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
        q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
        k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)

        prior_mask = torch.zeros(n_queries, n_keys)
        new_masks: list[torch.Tensor] = []
        for seq_id in range(num_seqs):
            ri = q_cumsum[seq_id]
            ci = k_cumsum[seq_id]
            nr = query_lens[seq_id]

            if contexted:
                nc = seq_lens[seq_id]
                a_offset = ci + nc - ri - nr
                new_mask = (a + a_offset) >= b
            else:
                nc = context_lens[seq_id]
                a_offset = ci + nc - 1
                new_mask = a_offset >= b

            left_mask = b >= ci
            top_mask = a >= ri
            bottom_mask = a < (ri + nr)

            new_mask = logical_and(
                logical_and(logical_and(new_mask, left_mask), top_mask),
                bottom_mask,
            )
            prior_mask = logical_or(prior_mask, new_mask)
            new_masks = new_masks + [new_mask]
        return prior_mask

    @staticmethod
    def from_seqlens(query_lens, seq_lens, block_size=None):
        contexted = block_size is None
        if contexted:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens)
            active_mask = None
        else:
            prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, seq_lens, block_size)
            active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens(
                query_lens, query_lens)
        return prior_mask, active_mask


def ref_softmax(x: torch.Tensor,
                dim: int,
                mixed_precision=False,
                return_max_reduce=False):
    max_value = torch.amax(x, dim=dim, keepdims=True)
    exp = torch.exp(x - max_value)
    if mixed_precision:
        sum_value = torch.sum(exp.astype(torch.float32),
                              dim=dim,
                              keepdims=True).astype(x.dtype)
    else:
        sum_value = torch.sum(exp, dim=dim, keepdims=True)
    if return_max_reduce:
        return exp / sum_value, max_value, torch.reciprocal(sum_value)
    return exp / sum_value


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    return_max_reduce: Optional[bool] = False,
) -> torch.Tensor:
    scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        masked_score = scaled_qk + attn_mask.float()
    if return_max_reduce:
        norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)
    out = torch.einsum("hqk,khd->qhd", norm_score, value)
    if return_max_reduce:
        return (
            out,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    else:
        return out


def ref_context_attention(
    query,
    key,
    value,
    query_lens,
    seq_lens,
    head_size,
    num_kv_heads,
    num_heads,
    num_queries_per_kv,
    return_max_reduce=False,
):
    scale = float(1.0 / (head_size**0.5))
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)

    # convert binary mask to -inf values
    attn_mask = torch.logical_not(attn_mask)
    attn_mask = attn_mask.float() * -30000

    output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = (
        ref_masked_attention(
            query,
            key,
            value,
            scale,
            attn_mask,
            return_max_reduce=return_max_reduce,
        ))

    output = output.unsqueeze(1)
    if return_max_reduce:
        return (
            output,
            cached_max,
            cached_sum_reciprocal,
            lse,
            masked_score,
            scaled_qk,
        )
    else:
        return output


@pytest.mark.parametrize(
    "num_heads,num_queries_per_kv,head_size,mixed_precision",
    [
        (4, 2, 8, False),
        (4, 2, 8, True),
        (32, 8, 64, True),
    ],
)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    mixed_precision: bool,
) -> None:
    import os

    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

    device = xm.xla_device()

    os.environ["NEURON_CC_FLAGS"] = (
        " --model-type=transformer -O1 "
        " --internal-hlo2tensorizer-options='--verify-hlo' ")

    random.seed(0)
    torch.manual_seed(0)
    torch.set_printoptions(sci_mode=False)

    min_ctx_len = 2
    max_ctx_len = 64
    min_query_len = 2
    max_query_len = 64
    prefill_batch_size = 2
    decode_batch_size = 6
    batch_size = prefill_batch_size + decode_batch_size
    block_size = 32
    max_model_len = (max_query_len + max_ctx_len) * 4

    max_block_per_request = max_model_len // block_size
    dtype = torch.float32
    cache_size = (batch_size * max_block_per_request) + 2
    ctx_lens = [
        random.randint(min_ctx_len, max_ctx_len)
        for _ in range(prefill_batch_size)
    ] + [
        random.randint(min_ctx_len, max_ctx_len)
        for _ in range(decode_batch_size)
    ]
    query_lens = [
        random.randint(min_query_len, max_query_len)
        for _ in range(prefill_batch_size)
    ] + [1 for _ in range(decode_batch_size)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1, 1)
    torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:batch_size * max_block_per_request].view(
        batch_size, max_block_per_request)
    torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0)
    for i in range(batch_size):
        for j in range(query_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             key[start_loc:end_loc])
            v_cache.view(-1, num_kv_heads,
                         head_size)[start_slot:end_slot].copy_(
                             value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1

    (
        output_ref,
        cached_max,
        cached_sum_reciprocal,
        lse,
        masked_score,
        scaled_qk,
    ) = ref_context_attention(
        query,
        key,
        value,
        query_lens,
        seq_lens,
        head_size,
        num_kv_heads,
        num_heads,
        num_queries_per_kv,
        return_max_reduce=True,
    )

    # build neuron program
    return_debug_tensors = False
    B_P_SIZE = 128
    LARGE_TILE_SZ = 2048
    max_num_queries = (
        (sum(query_lens) + block_size - 1) // block_size) * block_size

    def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                num_blocks):
        context_lens = seq_lens - query_lens
        blocks_per_seq = (context_lens + block_size - 1) // block_size
        num_seqs = len(seq_lens)
        active_blocks: list[int] = []
        for seq_id in range(num_seqs):
            active_blocks = (
                active_blocks +
                block_tables[seq_id, :blocks_per_seq[seq_id]].tolist())
        return F.pad(
            torch.tensor(active_blocks),
            (0, num_blocks - len(active_blocks)),
            "constant",
            0,
        )

    def shift_bit_length(x):
        return 1 << (x - 1).bit_length()

    # calculate input shapes
    max_num_queries_shifted = shift_bit_length(max_num_queries)
    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
    assert (max_num_queries_padded == B_P_SIZE
            ), f"invalid {max_num_queries_padded=}"
    head_size_padded = B_P_SIZE
    context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
    num_active_blocks_shifted = shift_bit_length(
        ((context_lens + block_size - 1) // block_size).sum().item())
    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
                                num_active_blocks_shifted)
    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
    assert (num_active_blocks *
            block_size) == LARGE_TILE_SZ, f"invalid {num_active_blocks=}"
    context_kv_len = num_active_blocks * block_size
    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"

    # pad QKV tensors
    pad_dims = (
        0,
        head_size_padded - query.shape[2],
        0,
        0,
        0,
        max_num_queries_padded - query.shape[0],
    )
    query = F.pad(query, pad_dims, "constant", 0)
    k = F.pad(k, pad_dims, "constant", 0)
    v = F.pad(v, pad_dims, "constant", 0)
    k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0)
    v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0)

    # permute QKV tensors
    # query: (1, n_heads, d, seq_q)
    # key:   (1, n_kv_heads, d, seq_k)
    # value: (1, n_kv_heads, seq_v, d)
    query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
    v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()

    # transform block table
    active_block_table = get_active_block_tables(
        block_table,
        torch.tensor(query_lens),
        torch.tensor(seq_lens),
        block_size,
        num_active_blocks,
    )

    # Build attention masks
    prior_mask, active_mask = (
        BlockDiagonalCausalFromBottomRightMask.from_seqlens(
            query_lens, seq_lens, block_size=block_size))
    attn_mask = torch.concat(
        [
            F.pad(
                prior_mask,
                (
                    0,
                    context_kv_len - prior_mask.shape[1],
                    0,
                    B_P_SIZE - prior_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
            F.pad(
                active_mask,
                (
                    0,
                    B_P_SIZE - active_mask.shape[1],
                    0,
                    B_P_SIZE - active_mask.shape[0],
                ),
                "constant",
                0,
            ).bool(),
        ],
        dim=1,
    )

    input_args = (
        query.to(device=device),
        k.to(device=device),
        v.to(device=device),
        k_cache.to(device=device),
        v_cache.to(device=device),
        active_block_table.to(torch.int32).to(device=device),
        attn_mask.to(device=device),
    )
    input_kwargs = dict(
        n_kv_head=num_kv_heads,
        head_size=head_size,
        mixed_precision=mixed_precision,
    )

    if return_debug_tensors:
        output_nki, *debug_tensors = flash_attn_varlen_nkifunc(
            *input_args, **input_kwargs)
    else:
        output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
        debug_tensors = []

    output_nki = torch.tensor(output_nki).cpu()
    debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]

    num_actual_tokens = sum(query_lens)
    print(f"{num_actual_tokens=}")
    # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
    output_nki = output_nki.permute(
        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
    output_ref_padded = F.pad(
        output_ref,
        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
        "constant",
        0,
    )
    output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :]

    torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
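As a quick aid for reviewing the mask helper above: with a single sequence of 2 new query tokens on top of 2 cached tokens, the bottom-right-aligned causal mask lets the first query attend to keys 0-2 and the second to keys 0-3. The check below is illustrative only (it assumes the test module is importable from the repo root) and is not part of the commit:

import torch

from tests.neuron.test_prefix_prefill import (
    BlockDiagonalCausalFromBottomRightMask)

# One sequence: 2 new query tokens appended after 2 already-cached tokens.
prior, active = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
    query_lens=[2], seq_lens=[4])
assert active is None  # contexted mode (block_size=None) returns one mask
expected = torch.tensor([[True, True, True, False],
                         [True, True, True, True]])
assert torch.equal(prior, expected)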
vllm/attention/ops/nki_flash_attn.py (new file, 669 lines)
@@ -0,0 +1,669 @@
from dataclasses import dataclass

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
from neuronxcc import nki
from neuronxcc.nki.language import par_dim


@dataclass(frozen=True)
class FlashConfig:
    """
    Config class for flash attention with default values
    """

    seq_tile_size: int = 2048
    should_transpose_v: bool = False

    __annotations__ = {
        "seq_tile_size": int,
        "should_transpose_v": bool,
    }


@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      forward_mask,
                      B_F_SIZE=512):
    for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
        if nisa.get_nc_version() == nisa.nc_version.gen3:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.sbuf,
                                       dtype=p_local.dtype)
        else:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.psum,
                                       dtype=np.float32)

        for j in nl.affine_range(B_F_SIZE // 128):
            j_128_slice = nl.ds(j * 128, 128)
            i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)

            if nisa.get_nc_version() == nisa.nc_version.gen3:
                p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
                    p_local[:, i_j_128_slice], mask=forward_mask)
            else:
                p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                    p_local[:, i_j_128_slice], mask=forward_mask)

        p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
            p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)


@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    q_h_per_k_h,
    seqlen_q,
    nheads,
    o_buffer,
    l_buffer,
    m_buffer,
    batch_id,
    head_id,
    gqa_head_idx,
    q_tile_idx,
    local_k_large_tile_idx,
    kernel_dtype,
    acc_type,
    flash_config: FlashConfig,
    use_causal_mask=False,
    continuous_batching_mask=None,
    initialize=False,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    dropout_p=0.0,
    dropout_p_tensor=None,
    seed_tensor=None,
    logit_bias_tile=None,
    qk_res_buffer=None,
):
    """
    The flash attention core function computes self-attention between a tile
    of q and a block of K and V.
    The q_local_tile has shape (B_P_SIZE, B_F_SIZE) and is already loaded into
    SBUF. The block size of K and V is defined by seq_tile_size in the
    flash_config. The results are stored in the following three buffers:
      o_buffer: (B_P_SIZE, d)
      l_buffer: (B_P_SIZE, 1)
      m_buffer: (B_P_SIZE, 1)
    """
    LARGE_TILE_SZ = flash_config.seq_tile_size
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
    seqlen_k = k.shape[-1]
    seqlen_q // B_P_SIZE
    seqlen_k // B_F_SIZE

    # TODO : support logit_bias with continuous_batching_mask
    assert not use_causal_mask, "causal mask is not supported."
    assert (continuous_batching_mask
            is not None), "continuous_batching_mask input is required."
    if continuous_batching_mask is not None:
        assert (logit_bias_tile is
                None), "continuous_batching_mask does not support logit_bias!"

    # masks are used to apply computation only to the lower half of the matrix,
    # which reduces the arithmetic intensity by half
    forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
                    LARGE_TILE_SZ if use_causal_mask else None)

    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
                           dtype=np.float32,
                           buffer=nl.psum)  # (128, 512)
        qk_psum[:, :] = nl.matmul(q_local_tile,
                                  k[:, k_i_b_f_slice],
                                  transpose_x=True,
                                  mask=None)  # (p(128), 512)

        qk_res_buf[:, k_i_b_f_slice] = nl.where(
            continuous_batching_mask[:, k_i_b_f_slice],
            qk_psum[:, nl.ds(0, B_F_SIZE)],
            -9984.0,
            dtype=acc_type,
        )

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
            mask=forward_mask,
        )

    if qk_res_buffer is not None:
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
        mask=forward_mask,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_,
                                    mask=forward_mask)  # (128,1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
            mask=forward_mask,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
                                             alpha,
                                             mask=forward_mask)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type)

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # Compute partial row-tile sum of exp(qk - max)
        # FIXME : Use activation accumulate to accumulate over k_r_i loop ?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
            mask=forward_mask,
        )

    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        forward_mask=forward_mask,
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE),
                       dtype=np.float32,
                       buffer=nl.psum)
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[k_i, :, :],
            transpose_x=True,
            mask=forward_mask,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)

        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(
                nl.subtract(l_prev, m_current, mask=forward_mask),
                mask=forward_mask,
            ),
            ps,
            mask=forward_mask,
        )
        l_buffer[:, 0] = nl.add(m_current,
                                nl.log(l_exp, mask=forward_mask),
                                mask=forward_mask)


@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
    LARGE_TILE_SZ = config.seq_tile_size
    B_P_SIZE = 128

    if not config.should_transpose_v:
        cur_v_tile[v_i, :, :] = nl.load(
            v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :],
            dtype=cur_v_tile.dtype,
        )
        return

    if nisa.get_nc_version() == nisa.nc_version.gen3:
        cur_v_tile_transposed = nisa.dma_transpose(
            v_hbm_tile[:,
                       nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)])
        cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed,
                                                 dtype=cur_v_tile.dtype)
        return

    cur_v_tile[v_i, :, :] = nl.load_transpose2d(
        v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)],
        dtype=cur_v_tile.dtype,
    )


@nki.jit
def flash_paged_attention(
    query,
    key,
    value,
    key_cache,
    value_cache,
    block_tables,
    mask,
    softmax_scale=None,
    mixed_precision=True,
    config=None,
    return_debug_tensors=False,
):
    """
    Flash PagedAttention Forward Kernel.
      - PagedAttention Paper: https://arxiv.org/abs/2309.06180
      - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310

    IO tensor layouts:
      - query: shape (1, n_heads, d, seq_q)
      - key: shape (1, n_kv_heads, d, seq_k)
      - value: shape (1, n_kv_heads, seq_v, d)
      - key_cache: (num_blocks, block_size, n_kv_heads, d)
      - value_cache: (num_blocks, block_size, n_kv_heads, d)
      - block_tables: (num_active_blocks, )
      - mask: (seq_q, num_active_blocks * block_size)
      - o: shape (1, n_heads, seq_q, d)
      - l_m: shape (1, n_heads, seq_q, 2)

      - This kernel requires seq_k == seq_v
      - We use continuous batching by default, so the batch dimension is
        always 1, and different requests are concatenated along the sequence
        dimension.
      - We use paged cache blocks (key_cache, value_cache) to store the KV
        cache.

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype except for
        block_tables (int32) and mask (int32)
      - If mixed_precision is True, all Tensor Engine operations will be
        performed in bfloat16 and accumulation will be performed in float32.
        Otherwise the intermediates will be in the same type as the inputs.

    Compile-time Constants:
      - softmax_scale: scaling for softmax; if None, the default is
        `1.0/(d**0.5)`
      - mixed_precision: flag to set non-matmul ops in fp32 precision; default
        is `true`; if false, we use the same precision as the input types
      - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
        with performance config parameters for flash attention with default
        values
        seq_tile_size: `default=2048`, size of the kv tile size for attention
        computation reduction

    GQA support Notes:
      the SPMD grid for launching the kernel should be over kv_heads instead
      of nheads

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    config = config or FlashConfig()
    B_F_SIZE = 512
    B_P_SIZE = 128
    b, h, d, seqlen_q = query.shape
    B_D_SIZE = d
    LARGE_TILE_SZ = config.seq_tile_size
    n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
    num_blocks, block_size, k_h, _ = key_cache.shape
    q_h_per_k_h = h // k_h
    assert tuple(key_cache.shape) == (
        num_blocks,
        block_size,
        k_h,
        d,
    ), "Input shape mismatch!"
    assert tuple(value_cache.shape) == (
        num_blocks,
        block_size,
        k_h,
        d,
    ), "Input shape mismatch!"
    assert b == 1, f"invalid batch size {b=}"
    assert d <= 128, f" we do not support head_dim > 128, got head dim {d}"
    kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype

    o = nl.ndarray((b, h, seqlen_q, d),
                   dtype=query.dtype,
                   buffer=nl.shared_hbm)
    hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
        None,
        None,
        None,
        None,
    )
    if return_debug_tensors:
        hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
                                  dtype=acc_type,
                                  buffer=nl.shared_hbm)
        hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
                                dtype=acc_type,
                                buffer=nl.shared_hbm)
        qk_res_buffer = nl.zeros(
            (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
            dtype=acc_type,
            buffer=nl.sbuf,
            lazy_initialization=True,
        )

    assert (
        nl.program_ndim() == 2
    ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
    batch_id = nl.program_id(axis=0)
    head_id = nl.program_id(axis=1)

    softmax_scale = softmax_scale or (1.0 / (d**0.5))

    (num_active_blocks, ) = block_tables.shape
    context_kv_len = num_active_blocks * block_size
    assert (config.seq_tile_size >= 512
            ), f" seq tile_size {config.seq_tile_size} cannot be less than 512"
    assert (context_kv_len % LARGE_TILE_SZ == 0
            ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
    assert (
        LARGE_TILE_SZ % B_P_SIZE == 0
    ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}"
    assert (B_P_SIZE % block_size == 0
            ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
    num_large_k_tile = context_kv_len // LARGE_TILE_SZ
    num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
    assert (num_blocks_per_large_tile <= B_P_SIZE
            ), f"The number of blocks in each large tile " \
        f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"

    block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
                                0,
                                dtype=np.int32,
                                buffer=nl.sbuf)
    for j in nl.affine_range(num_large_k_tile):
        i_p = nl.arange(num_blocks_per_large_tile)[:, None]
        block_tables_sbuf[i_p, j] = nl.load(
            block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)

    # Global Flash Attention accumulators
    o_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    l_buffer = nl.zeros(
        (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )
    m_buffer = nl.zeros(
        (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
        dtype=acc_type,
        buffer=nl.sbuf,
        lazy_initialization=True,
    )

    for j in nl.sequential_range(0, num_large_k_tile):
        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
            dtype=kernel_dtype,
        )

        for k_i in nl.affine_range(num_blocks_per_large_tile):
            loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
                                       head_id, :])
            cur_k_tile[:, nl.ds(k_i *
                                block_size, block_size)] = nl.transpose(loaded)

        load_tile_size = B_P_SIZE
        num_blocks_per_partition = load_tile_size // block_size
        for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
            for block_in_partition in nl.affine_range(
                    num_blocks_per_partition):
                v_i = (partition_idx * num_blocks_per_partition +
                       block_in_partition)
                loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
                                               head_id, :])
                cur_v_tile[partition_idx,
                           nl.ds(block_in_partition *
                                 block_size, block_size), :, ] = loaded_v

        cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                              dtype=mask.dtype)
        for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
            cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
                mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])

        for i_q_h in nl.affine_range(q_h_per_k_h):
            for i in nl.affine_range(n_tile_q):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
                    dtype=kernel_dtype,
                )  # load (d, 128) tile in SBUF
                q_tile[:, :] = q_sbuf_tile * softmax_scale

                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    q_h_per_k_h=q_h_per_k_h,
                    seqlen_q=seqlen_q,
                    nheads=h,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[:, i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    batch_id=batch_id,
                    head_id=head_id,
                    gqa_head_idx=i_q_h,
                    q_tile_idx=i,
                    local_k_large_tile_idx=j,
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    flash_config=config,
                    use_causal_mask=False,
                    continuous_batching_mask=cur_mask,
                    initialize=j == 0,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    dropout_p=0.0,
                    dropout_p_tensor=None,
                    seed_tensor=None,
                    logit_bias_tile=None,
                )

    # compute attention between input query, key and value
    if key is not None and value is not None:
        B_F_SIZE = seqlen_q
        LARGE_TILE_SZ = seqlen_q
        active_config = FlashConfig(
            seq_tile_size=LARGE_TILE_SZ,
            should_transpose_v=config.should_transpose_v,
        )

        cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
                                dtype=kernel_dtype)
        cur_v_tile = nl.ndarray(
            (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE),
            dtype=kernel_dtype,
        )

        cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :])

        load_tile_size = B_P_SIZE
        v_hbm_tile = value[batch_id, head_id]
        for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size):
            load_v_tile(
                v_hbm_tile=v_hbm_tile,
                cur_v_tile=cur_v_tile,
                j=0,
                v_i=v_i,
                config=active_config,
            )

        cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
        cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])

        for i_q_h in nl.affine_range(q_h_per_k_h):
            for i in nl.affine_range(n_tile_q):
                q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
                q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
                q_sbuf_tile = nl.load(
                    q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)],
                    dtype=kernel_dtype,
                )  # load (d, 128) tile in SBUF
                q_tile[:, :] = q_sbuf_tile * softmax_scale
                _flash_attention_core(
                    q_local_tile=q_tile,
                    k=cur_k_tile,
                    v=cur_v_tile,
                    q_h_per_k_h=q_h_per_k_h,
                    seqlen_q=seqlen_q,
                    nheads=h,
                    o_buffer=o_buffer[i, i_q_h],
                    l_buffer=l_buffer[:, i, i_q_h],
                    m_buffer=m_buffer[i, i_q_h],
                    batch_id=batch_id,
                    head_id=head_id,
                    gqa_head_idx=i_q_h,
                    q_tile_idx=i,
                    local_k_large_tile_idx=0,
                    kernel_dtype=kernel_dtype,
                    acc_type=acc_type,
                    flash_config=active_config,
                    use_causal_mask=False,
                    continuous_batching_mask=cur_mask,
                    initialize=False,
                    B_P_SIZE=B_P_SIZE,
                    B_F_SIZE=B_F_SIZE,
                    B_D_SIZE=B_D_SIZE,
                    dropout_p=0.0,
                    dropout_p_tensor=None,
                    seed_tensor=None,
                    logit_bias_tile=None,
                    qk_res_buffer=qk_res_buffer[i, i_q_h]
                    if qk_res_buffer is not None else None,
                )

    # -------- write output to buffer on HBM -------- #
    for i_q_h in nl.affine_range(q_h_per_k_h):
        for i in nl.affine_range(n_tile_q):
            out = nl.multiply(
                o_buffer[i, i_q_h, :, :],
                nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]),
                dtype=kernel_dtype,
            )

            nl.store(
                o[batch_id, head_id * q_h_per_k_h + i_q_h,
                  nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
                out,
            )
            # maximum and summation statistics
            if return_debug_tensors:
                nl.store(
                    hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
                    m_buffer[i, i_q_h, :, :],
                )
                nl.store(
                    hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
                                 nl.ds(i * B_P_SIZE, B_P_SIZE), ],
                    l_buffer[:, i, i_q_h],
                )
                nl.store(
                    hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
                    qk_res_buffer[batch_id, i_q_h, :, :],
                )

    if return_debug_tensors:
        return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
    return o


def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    key_cache,
    value_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    B_P_SIZE=128,
    LARGE_TILE_SZ=2048,
    return_debug_tensors=False,
    mixed_precision=True,
):
    config = FlashConfig(
        seq_tile_size=LARGE_TILE_SZ,
        should_transpose_v=False,
    )
    kwargs = dict(
        query=query,
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        block_tables=block_table,
        mask=attn_mask,
        softmax_scale=1.0 / (head_size**0.5),
        config=config,
        mixed_precision=mixed_precision,
        return_debug_tensors=return_debug_tensors,
    )
    _, n_kv_head, _, _ = key.shape

    if return_debug_tensors:
        o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs)
        return o, *debug_tensors
    else:
        o = flash_paged_attention[1, n_kv_head](**kwargs)
        return o
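One design note on the GQA section of the docstring above: the SPMD grid is launched over (batch, kv heads) rather than query heads, and the flash_attn_varlen_nkifunc wrapper follows that convention. A hypothetical direct launch (illustrative only, with the arguments assumed to already match the layouts listed in the kernel docstring) would look like this:

from vllm.attention.ops.nki_flash_attn import (FlashConfig,
                                               flash_paged_attention)


def launch_paged_attention(query, key, value, key_cache, value_cache,
                           block_tables, mask, head_size, n_kv_heads):
    # Grid is (batch=1, kv heads); each program instance handles the
    # q_h_per_k_h query heads that share one KV head.
    config = FlashConfig(seq_tile_size=2048, should_transpose_v=False)
    return flash_paged_attention[1, n_kv_heads](
        query=query,
        key=key,
        value=value,
        key_cache=key_cache,
        value_cache=value_cache,
        block_tables=block_tables,
        mask=mask,
        softmax_scale=1.0 / (head_size**0.5),
        mixed_precision=True,
        config=config,
    )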