[Neuron][kernel] Fuse kv cache into a single tensor (#15911)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
This commit is contained in:
Liangfu Chen 2025-04-03 09:51:32 -07:00 committed by GitHub
parent 82e7e19a6e
commit d2b58ca203
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 56 deletions

View File

@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
key_cache = torch.zeros_like(key_cache_cpu, device=device) key_cache = torch.zeros_like(key_cache_cpu, device=device)
value_cache = torch.zeros_like(value_cache_cpu, device=device) value_cache = torch.zeros_like(value_cache_cpu, device=device)
slot_mapping = slot_mapping_cpu.to(device) slot_mapping = slot_mapping_cpu.to(device)
kv_cache = torch.stack([key_cache, value_cache])
# Run vectorized implementation on XLA device # Run vectorized implementation on XLA device
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) reshape_and_cache(key, value, kv_cache, slot_mapping)
key_cache, value_cache = torch.unbind(kv_cache, dim=0)
# Move results back to CPU for comparison # Move results back to CPU for comparison
key_cache_result = key_cache.cpu() key_cache_result = key_cache.cpu()

View File

@ -258,13 +258,13 @@ def sample_inputs(
value[start_loc:end_loc]) value[start_loc:end_loc])
cur_ctx += block_size cur_ctx += block_size
block_id += 1 block_id += 1
kv_cache = torch.stack([k_cache, v_cache])
return ( return (
query, query,
k, k,
v, v,
k_cache, kv_cache,
v_cache,
block_table, block_table,
key, key,
value, value,
@ -361,8 +361,7 @@ def test_contexted_kv_attention(
query, query,
k_active, k_active,
v_active, v_active,
k_cache, kv_cache,
v_cache,
block_table, block_table,
key, key,
value, value,
@ -439,8 +438,7 @@ def test_contexted_kv_attention(
query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
k_cache = k_cache.permute(0, 2, 1, 3).contiguous() kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
# transform block table # transform block table
active_block_table = get_active_block_tables( active_block_table = get_active_block_tables(
@ -487,8 +485,7 @@ def test_contexted_kv_attention(
query.to(device=device), query.to(device=device),
k.to(device=device), k.to(device=device),
v.to(device=device), v.to(device=device),
k_cache.to(device=device), kv_cache.to(device=device),
v_cache.to(device=device),
active_block_table.to(device=device), active_block_table.to(device=device),
attn_mask.to(device=device), attn_mask.to(device=device),
) )

View File

@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
def load_kv_tile_from_cache( def load_kv_tile_from_cache(
cur_k_tile, cur_k_tile,
cur_v_tile, cur_v_tile,
key_cache, kv_cache,
value_cache,
block_tables, block_tables,
large_k_tile_idx, large_k_tile_idx,
num_blocks_per_large_tile, num_blocks_per_large_tile,
@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
for load_idx in nl.affine_range(num_loads): for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None] i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(key_cache[block_tables[load_idx, i_p, loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
large_k_tile_idx], i_f]) large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype: if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE # Transpose SBUF tensor using PE
@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
# load value cache # load value cache
for load_idx in nl.affine_range(num_loads): for load_idx in nl.affine_range(num_loads):
loaded = nl.load(value_cache[block_tables[load_idx, i_p, loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
large_k_tile_idx], i_f]) large_k_tile_idx], i_f])
if cur_v_tile.dtype != loaded.dtype: if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
@ -418,8 +417,7 @@ def flash_paged_attention(
query, query,
key, key,
value, value,
key_cache, kv_cache,
value_cache,
block_tables, block_tables,
mask, mask,
softmax_scale=None, softmax_scale=None,
@ -434,8 +432,7 @@ def flash_paged_attention(
- query: shape (1, n_heads, d, seq_q) - query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k) - key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d) - value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, n_kv_heads, block_size, d) - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- value_cache: (num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, ) - block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q) - mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d) - o: shape (1, n_heads, seq_q, d)
@ -444,7 +441,7 @@ def flash_paged_attention(
- We use continuous batching by default, so the batch dimension is - We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence always 1, and different requests are concatenated along sequence
dimension. dimension.
- We use paged cache blocks (key_cache, value_cache) to store KV cache. - We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes: IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for - This kernel assumes all IO tensors have the same dtype except for
@ -475,15 +472,13 @@ def flash_paged_attention(
b, h, d, seqlen_q = query.shape b, h, d, seqlen_q = query.shape
B_D_SIZE = d B_D_SIZE = d
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
num_blocks, k_h, block_size, _ = key_cache.shape _, num_blocks, k_h, block_size, _ = kv_cache.shape
q_h_per_k_h = h // k_h q_h_per_k_h = h // k_h
assert b == 1, f"invalid batch size {b=}" assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (num_blocks, k_h, block_size, d) cache_shape = (2, num_blocks, k_h, block_size, d)
assert (tuple(key_cache.shape) == cache_shape assert (tuple(kv_cache.shape) == cache_shape
), f"{key_cache.shape=} mismatch, expect {cache_shape}" ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
assert (tuple(value_cache.shape) == cache_shape
), f"{value_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == ( assert key is None or tuple(key.shape) == (
1, 1,
k_h, k_h,
@ -580,13 +575,13 @@ def flash_paged_attention(
head_id=head_id, head_id=head_id,
) )
# Flatten KV cache to be 2D for loading into SBUF # Flatten KV cache to be 3D for loading into SBUF
new_cache_shape = ( new_cache_shape = (
2,
num_blocks * k_h * block_size_tiling_factor, num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d, tiled_block_size * d,
) )
key_cache = key_cache.reshape(new_cache_shape) kv_cache = kv_cache.reshape(new_cache_shape)
value_cache = value_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators # Global Flash Attention accumulators
o_buffer = nl.zeros( o_buffer = nl.zeros(
@ -621,8 +616,7 @@ def flash_paged_attention(
load_kv_tile_from_cache( load_kv_tile_from_cache(
cur_k_tile=cur_k_tile, cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile, cur_v_tile=cur_v_tile,
key_cache=key_cache, kv_cache=kv_cache,
value_cache=value_cache,
block_tables=block_tables_sbuf, block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx, large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile, num_blocks_per_large_tile=num_blocks_per_large_tile,
@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
query, query,
key, key,
value, value,
key_cache, kv_cache,
value_cache,
block_table, block_table,
attn_mask, attn_mask,
n_kv_head=None, n_kv_head=None,
@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
- query: (1, n_heads, d, seq_q) - query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k) - key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d) - value: (1, n_kv_heads, seq_v, d)
- key_cache: (n_blocks, n_kv_heads, block_size, d) - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- value_cache: (n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, ) - block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q) - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
for better DMA throughput for better DMA throughput
""" """
if n_kv_head is None: if n_kv_head is None:
n_kv_head = key_cache.shape[1] n_kv_head = kv_cache.shape[2]
assert key_cache.shape[1] == n_kv_head assert kv_cache.shape[0] == 2
assert kv_cache.shape[2] == n_kv_head
if head_size is None: if head_size is None:
head_size = key_cache.shape[-1] head_size = kv_cache.shape[-1]
kwargs = dict( kwargs = dict(
query=query, query=query,
key=key, key=key,
value=value, value=value,
key_cache=key_cache, kv_cache=kv_cache,
value_cache=value_cache,
block_tables=block_table, block_tables=block_table,
mask=attn_mask, mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5), softmax_scale=1.0 / (head_size**0.5),
@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
def reshape_and_cache( def reshape_and_cache(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
key_cache: torch.Tensor, kv_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
) -> None: ) -> None:
""" """
@ -886,29 +877,29 @@ def reshape_and_cache(
(num_tokens, n_kv_head, d_head) (num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head) (num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape kv_cache (torch.Tensor): Key/value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head) (2, num_blocks, n_kv_head, block_size, d_head)
value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens) with shape (num_tokens)
Returns: Returns:
None: Updates the key_cache and value_cache tensors in-place None: Updates the kv_cache tensor in-place
""" """
block_size = key_cache.size(2) block_size = kv_cache.size(3)
n_kv_head = key.size(1)
# Calculate indices with explicit floor division # Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size block_offsets = slot_mapping % block_size
# Update caches using index_put_ # Create the head indices tensor
key_cache.index_put_( head_indices = torch.arange(n_kv_head, device=key.device)
(block_indices.unsqueeze(1),
torch.arange(key_cache.size(1),
device=key.device), block_offsets.unsqueeze(1)), key)
value_cache.index_put_( # Update caches using index_put_
(block_indices.unsqueeze(1), kv_cache.index_put_(
torch.arange(value_cache.size(1), (torch.tensor([0], device=key.device), block_indices[:, None],
device=value.device), block_offsets.unsqueeze(1)), value) head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_(
(torch.tensor([1], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), value)