[Neuron][kernel] Fuse kv cache into a single tensor (#15911)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
This commit is contained in:
parent 82e7e19a6e
commit d2b58ca203
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
     key_cache = torch.zeros_like(key_cache_cpu, device=device)
     value_cache = torch.zeros_like(value_cache_cpu, device=device)
     slot_mapping = slot_mapping_cpu.to(device)
+    kv_cache = torch.stack([key_cache, value_cache])
 
     # Run vectorized implementation on XLA device
-    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+    reshape_and_cache(key, value, kv_cache, slot_mapping)
+    key_cache, value_cache = torch.unbind(kv_cache, dim=0)
 
     # Move results back to CPU for comparison
     key_cache_result = key_cache.cpu()
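The fused cache used by the updated test is simply the two per-tensor caches stacked along a new leading dimension, and torch.unbind recovers key/value views of that same storage for the comparison. A minimal standalone sketch with made-up sizes:

import torch

num_blocks, n_kv_head, block_size, d_head = 2, 1, 4, 8
key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
value_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)

# torch.stack copies both caches into one (2, ...) tensor; the kernel now
# receives this single tensor instead of two separate cache arguments.
kv_cache = torch.stack([key_cache, value_cache])
assert kv_cache.shape == (2, num_blocks, n_kv_head, block_size, d_head)

# torch.unbind returns views into kv_cache, so anything written into
# kv_cache is visible through the unpacked key/value halves.
key_view, value_view = torch.unbind(kv_cache, dim=0)
kv_cache[0, 0, 0, 0, 0] = 1.0
assert key_view[0, 0, 0, 0] == 1.0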
@@ -258,13 +258,13 @@ def sample_inputs(
                                    value[start_loc:end_loc])
         cur_ctx += block_size
         block_id += 1
+    kv_cache = torch.stack([k_cache, v_cache])
 
     return (
         query,
         k,
         v,
-        k_cache,
-        v_cache,
+        kv_cache,
         block_table,
         key,
         value,
@@ -361,8 +361,7 @@ def test_contexted_kv_attention(
         query,
         k_active,
         v_active,
-        k_cache,
-        v_cache,
+        kv_cache,
         block_table,
         key,
         value,
@@ -439,8 +438,7 @@ def test_contexted_kv_attention(
     query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
     k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
     v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
-    k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
-    v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
+    kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
 
     # transform block table
     active_block_table = get_active_block_tables(
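The new permutation on the fused tensor is the old per-cache permutation shifted by one to account for the leading key/value dimension. A small sketch (made-up sizes, not from the test) showing the equivalence:

import torch

k_cache = torch.randn(4, 2, 8, 16)   # (num_blocks, n_kv_head, block_size, d_head)
v_cache = torch.randn(4, 2, 8, 16)
kv_cache = torch.stack([k_cache, v_cache])

# Permuting dims (1, 3, 2, 4) of the fused cache matches applying the old
# (0, 2, 1, 3) permutation to each half and re-stacking.
fused = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
ref = torch.stack([k_cache.permute(0, 2, 1, 3),
                   v_cache.permute(0, 2, 1, 3)]).contiguous()
assert torch.equal(fused, ref)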
@@ -487,8 +485,7 @@ def test_contexted_kv_attention(
         query.to(device=device),
         k.to(device=device),
         v.to(device=device),
-        k_cache.to(device=device),
-        v_cache.to(device=device),
+        kv_cache.to(device=device),
         active_block_table.to(device=device),
         attn_mask.to(device=device),
     )
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
 def load_kv_tile_from_cache(
     cur_k_tile,
     cur_v_tile,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_tables,
     large_k_tile_idx,
     num_blocks_per_large_tile,
@@ -169,7 +168,7 @@ def load_kv_tile_from_cache(
     for load_idx in nl.affine_range(num_loads):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
-        loaded = nl.load(key_cache[block_tables[load_idx, i_p,
+        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                 large_k_tile_idx], i_f])
         if cur_k_tile.dtype != loaded.dtype:
             loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
 
     # load value cache
     for load_idx in nl.affine_range(num_loads):
-        loaded = nl.load(value_cache[block_tables[load_idx, i_p,
+        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                   large_k_tile_idx], i_f])
         if cur_v_tile.dtype != loaded.dtype:
             loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
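After fusion the keys live in plane 0 of the cache and the values in plane 1, so both load loops index the same tensor and differ only in that leading index. A host-side torch analogue (not NKI, made-up sizes) of the gather:

import torch

num_rows, row_len = 16, 8            # hypothetical flattened cache extent
kv_cache = torch.arange(2 * num_rows * row_len, dtype=torch.float32)
kv_cache = kv_cache.reshape(2, num_rows, row_len)

block_tables = torch.tensor([3, 7, 12])   # rows selected for the current large tile
k_tile = kv_cache[0, block_tables]        # what the key-load loop gathers
v_tile = kv_cache[1, block_tables]        # what the value-load loop gathers
assert k_tile.shape == v_tile.shape == (3, row_len)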
@@ -418,8 +417,7 @@ def flash_paged_attention(
     query,
     key,
     value,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_tables,
     mask,
     softmax_scale=None,
@@ -434,8 +432,7 @@ def flash_paged_attention(
     - query: shape (1, n_heads, d, seq_q)
     - key: shape (1, n_kv_heads, d, seq_k)
     - value: shape (1, n_kv_heads, seq_v, d)
-    - key_cache: (num_blocks, n_kv_heads, block_size, d)
-    - value_cache: (num_blocks, n_kv_heads, block_size, d)
+    - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
     - block_tables: (num_active_blocks, )
     - mask: (seq_q, num_active_blocks * block_size + seq_q)
     - o: shape (1, n_heads, seq_q, d)
@@ -444,7 +441,7 @@ def flash_paged_attention(
     - We use continuous batching by default, so the batch dimension is
       always 1, and different requests are concatenated along sequence
       dimension.
-    - We use paged cache blocks (key_cache, value_cache) to store KV cache.
+    - We use paged cache blocks (kv_cache) to store KV cache.
 
     IO tensor dtypes:
     - This kernel assumes all IO tensors have the same dtype except for
@@ -475,15 +472,13 @@ def flash_paged_attention(
     b, h, d, seqlen_q = query.shape
     B_D_SIZE = d
     n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
-    num_blocks, k_h, block_size, _ = key_cache.shape
+    _, num_blocks, k_h, block_size, _ = kv_cache.shape
     q_h_per_k_h = h // k_h
     assert b == 1, f"invalid batch size {b=}"
     assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
-    cache_shape = (num_blocks, k_h, block_size, d)
-    assert (tuple(key_cache.shape) == cache_shape
-            ), f"{key_cache.shape=} mismatch, expect {cache_shape}"
-    assert (tuple(value_cache.shape) == cache_shape
-            ), f"{value_cache.shape=} mismatch, expect {cache_shape}"
+    cache_shape = (2, num_blocks, k_h, block_size, d)
+    assert (tuple(kv_cache.shape) == cache_shape
+            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
     assert key is None or tuple(key.shape) == (
         1,
         k_h,
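With the fused layout a single shape check replaces the two per-cache checks, and the key/value split is just a leading dimension of size 2. A standalone sketch with made-up sizes:

import torch

num_blocks, k_h, block_size, d = 16, 4, 32, 128
key_cache = torch.randn(num_blocks, k_h, block_size, d)
value_cache = torch.randn(num_blocks, k_h, block_size, d)
kv_cache = torch.stack([key_cache, value_cache])

# One assertion now covers what the two per-cache assertions did before.
assert tuple(kv_cache.shape) == (2, num_blocks, k_h, block_size, d)
assert torch.equal(kv_cache[0], key_cache) and torch.equal(kv_cache[1], value_cache)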
@@ -580,13 +575,13 @@ def flash_paged_attention(
             head_id=head_id,
         )
 
-    # Flatten KV cache to be 2D for loading into SBUF
+    # Flatten KV cache to be 3D for loading into SBUF
     new_cache_shape = (
+        2,
         num_blocks * k_h * block_size_tiling_factor,
         tiled_block_size * d,
     )
-    key_cache = key_cache.reshape(new_cache_shape)
-    value_cache = value_cache.reshape(new_cache_shape)
+    kv_cache = kv_cache.reshape(new_cache_shape)
 
     # Global Flash Attention accumulators
     o_buffer = nl.zeros(
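The flatten keeps the key/value split as the leading dimension and only relayouts the per-block data; it is a pure reshape because block_size factors into block_size_tiling_factor * tiled_block_size. A sketch with made-up sizes:

import torch

# Hypothetical sizes; block_size must equal
# block_size_tiling_factor * tiled_block_size for the reshape to be valid.
num_blocks, k_h, block_size, d = 8, 2, 32, 64
block_size_tiling_factor, tiled_block_size = 4, 8
assert block_size_tiling_factor * tiled_block_size == block_size

kv_cache = torch.zeros(2, num_blocks, k_h, block_size, d)
new_cache_shape = (
    2,
    num_blocks * k_h * block_size_tiling_factor,
    tiled_block_size * d,
)
# The element count is unchanged, so this is only a relayout of the same buffer.
assert kv_cache.numel() == 2 * new_cache_shape[1] * new_cache_shape[2]
kv_cache = kv_cache.reshape(new_cache_shape)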
@@ -621,8 +616,7 @@ def flash_paged_attention(
             load_kv_tile_from_cache(
                 cur_k_tile=cur_k_tile,
                 cur_v_tile=cur_v_tile,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                kv_cache=kv_cache,
                 block_tables=block_tables_sbuf,
                 large_k_tile_idx=large_k_tile_idx,
                 num_blocks_per_large_tile=num_blocks_per_large_tile,
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
     query,
     key,
     value,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_table,
     attn_mask,
     n_kv_head=None,
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
     - query: (1, n_heads, d, seq_q)
     - key: (1, n_kv_heads, d, seq_k)
     - value: (1, n_kv_heads, seq_v, d)
-    - key_cache: (n_blocks, n_kv_heads, block_size, d)
-    - value_cache: (n_blocks, n_kv_heads, block_size, d)
+    - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
     - block_tables: (n_active_blocks, )
     - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
 
@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
         for better DMA throughput
     """
     if n_kv_head is None:
-        n_kv_head = key_cache.shape[1]
-    assert key_cache.shape[1] == n_kv_head
+        n_kv_head = kv_cache.shape[2]
+    assert kv_cache.shape[0] == 2
+    assert kv_cache.shape[2] == n_kv_head
     if head_size is None:
-        head_size = key_cache.shape[-1]
+        head_size = kv_cache.shape[-1]
 
     kwargs = dict(
         query=query,
         key=key,
         value=value,
-        key_cache=key_cache,
-        value_cache=value_cache,
+        kv_cache=kv_cache,
         block_tables=block_table,
         mask=attn_mask,
         softmax_scale=1.0 / (head_size**0.5),
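The wrapper now derives its defaults from the fused cache rather than from key_cache. A tiny sketch with a made-up cache shape:

import torch

# (2, n_blocks, n_kv_heads, block_size, d)
kv_cache = torch.zeros(2, 16, 4, 32, 128)

n_kv_head = kv_cache.shape[2]   # was key_cache.shape[1]
head_size = kv_cache.shape[-1]  # trailing head dimension is unchanged
assert kv_cache.shape[0] == 2 and (n_kv_head, head_size) == (4, 128)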
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
 ) -> None:
     """
@@ -886,29 +877,29 @@ def reshape_and_cache(
             (num_tokens, n_kv_head, d_head)
         value (torch.Tensor): Value tensor with shape
             (num_tokens, n_kv_head, d_head)
-        key_cache (torch.Tensor): Key cache tensor with shape
-            (num_blocks, n_kv_head, block_size, d_head)
-        value_cache (torch.Tensor): Value cache tensor with shape
-            (num_blocks, n_kv_head, block_size, d_head)
+        kv_cache (torch.Tensor): Key/value cache tensor with shape
+            (2, num_blocks, n_kv_head, block_size, d_head)
         slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
             with shape (num_tokens)
 
     Returns:
-        None: Updates the key_cache and value_cache tensors in-place
+        None: Updates the kv_cache tensor in-place
     """
-    block_size = key_cache.size(2)
+    block_size = kv_cache.size(3)
+    n_kv_head = key.size(1)
 
     # Calculate indices with explicit floor division
     block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
     block_offsets = slot_mapping % block_size
 
-    # Update caches using index_put_
-    key_cache.index_put_(
-        (block_indices.unsqueeze(1),
-         torch.arange(key_cache.size(1),
-                      device=key.device), block_offsets.unsqueeze(1)), key)
+    # Create the head indices tensor
+    head_indices = torch.arange(n_kv_head, device=key.device)
 
-    value_cache.index_put_(
-        (block_indices.unsqueeze(1),
-         torch.arange(value_cache.size(1),
-                      device=value.device), block_offsets.unsqueeze(1)), value)
+    # Update caches using index_put_
+    kv_cache.index_put_(
+        (torch.tensor([0], device=key.device), block_indices[:, None],
+         head_indices[None, :], block_offsets[:, None]), key)
+
+    kv_cache.index_put_(
+        (torch.tensor([1], device=key.device), block_indices[:, None],
+         head_indices[None, :], block_offsets[:, None]), value)