[Neuron][kernel] Fuse kv cache into a single tensor (#15911)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Parent: 82e7e19a6e
Commit: d2b58ca203
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
 key_cache = torch.zeros_like(key_cache_cpu, device=device)
 value_cache = torch.zeros_like(value_cache_cpu, device=device)
 slot_mapping = slot_mapping_cpu.to(device)
+kv_cache = torch.stack([key_cache, value_cache])

 # Run vectorized implementation on XLA device
-reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+reshape_and_cache(key, value, kv_cache, slot_mapping)
+key_cache, value_cache = torch.unbind(kv_cache, dim=0)

 # Move results back to CPU for comparison
 key_cache_result = key_cache.cpu()
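For illustration only (not part of this commit), a minimal sketch of the fuse/split pattern the test relies on, using made-up sizes: torch.stack copies the two caches into one tensor, and torch.unbind returns views of that fused tensor, so the kernel's in-place writes are visible through key_cache and value_cache.

import torch

# Illustrative sizes; the real test uses (num_blocks, n_kv_head, block_size, d_head).
key_cache = torch.zeros(4, 2, 8, 16)
value_cache = torch.zeros(4, 2, 8, 16)

# stack copies both caches into a single (2, *cache_shape) tensor.
kv_cache = torch.stack([key_cache, value_cache])
assert kv_cache.shape == (2, 4, 2, 8, 16)

# unbind returns views over the fused tensor, so in-place updates to
# kv_cache (as done by the kernel) show up in key_cache / value_cache.
key_cache, value_cache = torch.unbind(kv_cache, dim=0)
kv_cache[0, 0, 0, 0, 0] = 1.0
assert key_cache[0, 0, 0, 0] == 1.0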
@@ -258,13 +258,13 @@ def sample_inputs(
 value[start_loc:end_loc])
 cur_ctx += block_size
 block_id += 1
+kv_cache = torch.stack([k_cache, v_cache])

 return (
 query,
 k,
 v,
-k_cache,
-v_cache,
+kv_cache,
 block_table,
 key,
 value,
@@ -361,8 +361,7 @@ def test_contexted_kv_attention(
 query,
 k_active,
 v_active,
-k_cache,
-v_cache,
+kv_cache,
 block_table,
 key,
 value,
@@ -439,8 +438,7 @@ def test_contexted_kv_attention(
 query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
 k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
 v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
-k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
-v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
+kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()

 # transform block table
 active_block_table = get_active_block_tables(
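For illustration only (not part of this commit): with the leading K/V dimension fixed, the fused permute (0, 1, 3, 2, 4) is equivalent to applying the previous per-cache permute (0, 2, 1, 3) to each plane. A minimal sketch with made-up sizes:

import torch

k_cache = torch.randn(4, 2, 8, 16)   # (num_blocks, n_kv_heads, block_size, d)
v_cache = torch.randn(4, 2, 8, 16)
kv_cache = torch.stack([k_cache, v_cache])  # (2, num_blocks, n_kv_heads, block_size, d)

fused = kv_cache.permute(0, 1, 3, 2, 4).contiguous()
assert torch.equal(fused[0], k_cache.permute(0, 2, 1, 3))
assert torch.equal(fused[1], v_cache.permute(0, 2, 1, 3))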
@@ -487,8 +485,7 @@ def test_contexted_kv_attention(
 query.to(device=device),
 k.to(device=device),
 v.to(device=device),
-k_cache.to(device=device),
-v_cache.to(device=device),
+kv_cache.to(device=device),
 active_block_table.to(device=device),
 attn_mask.to(device=device),
 )
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
 def load_kv_tile_from_cache(
 cur_k_tile,
 cur_v_tile,
-key_cache,
-value_cache,
+kv_cache,
 block_tables,
 large_k_tile_idx,
 num_blocks_per_large_tile,
@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
 for load_idx in nl.affine_range(num_loads):
 i_p = nl.arange(B_P_SIZE)[:, None]
 i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
-loaded = nl.load(key_cache[block_tables[load_idx, i_p,
+loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
 large_k_tile_idx], i_f])
 if cur_k_tile.dtype != loaded.dtype:
 loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
 # Transpose SBUF tensor using PE
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(

 # load value cache
 for load_idx in nl.affine_range(num_loads):
-loaded = nl.load(value_cache[block_tables[load_idx, i_p,
+loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
 large_k_tile_idx], i_f])
 if cur_v_tile.dtype != loaded.dtype:
 loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
@@ -418,8 +417,7 @@ def flash_paged_attention(
 query,
 key,
 value,
-key_cache,
-value_cache,
+kv_cache,
 block_tables,
 mask,
 softmax_scale=None,
@@ -434,8 +432,7 @@ def flash_paged_attention(
 - query: shape (1, n_heads, d, seq_q)
 - key: shape (1, n_kv_heads, d, seq_k)
 - value: shape (1, n_kv_heads, seq_v, d)
-- key_cache: (num_blocks, n_kv_heads, block_size, d)
-- value_cache: (num_blocks, n_kv_heads, block_size, d)
+- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
 - block_tables: (num_active_blocks, )
 - mask: (seq_q, num_active_blocks * block_size + seq_q)
 - o: shape (1, n_heads, seq_q, d)
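For illustration only (not part of this commit), a minimal allocation sketch for the documented shapes, with made-up sizes:

import torch

n_heads, n_kv_heads, d = 8, 2, 128
seq_q, num_blocks, block_size = 128, 16, 32

query = torch.empty(1, n_heads, d, seq_q)
key = torch.empty(1, n_kv_heads, d, seq_q)    # seq_k == seq_q here
value = torch.empty(1, n_kv_heads, seq_q, d)  # seq_v == seq_q here
kv_cache = torch.empty(2, num_blocks, n_kv_heads, block_size, d)  # fused K/V cache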
@@ -444,7 +441,7 @@ def flash_paged_attention(
 - We use continuous batching by default, so the batch dimension is
 always 1, and different requests are concatenated along sequence
 dimension.
-- We use paged cache blocks (key_cache, value_cache) to store KV cache.
+- We use paged cache blocks (kv_cache) to store KV cache.

 IO tensor dtypes:
 - This kernel assumes all IO tensors have the same dtype except for
@@ -475,15 +472,13 @@ def flash_paged_attention(
 b, h, d, seqlen_q = query.shape
 B_D_SIZE = d
 n_tile_q = seqlen_q // B_P_SIZE  # since q will be loaded on tensor engine
-num_blocks, k_h, block_size, _ = key_cache.shape
+_, num_blocks, k_h, block_size, _ = kv_cache.shape
 q_h_per_k_h = h // k_h
 assert b == 1, f"invalid batch size {b=}"
 assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
-cache_shape = (num_blocks, k_h, block_size, d)
-assert (tuple(key_cache.shape) == cache_shape
-), f"{key_cache.shape=} mismatch, expect {cache_shape}"
-assert (tuple(value_cache.shape) == cache_shape
-), f"{value_cache.shape=} mismatch, expect {cache_shape}"
+cache_shape = (2, num_blocks, k_h, block_size, d)
+assert (tuple(kv_cache.shape) == cache_shape
+), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
 assert key is None or tuple(key.shape) == (
 1,
 k_h,
@@ -580,13 +575,13 @@ def flash_paged_attention(
 head_id=head_id,
 )

-# Flatten KV cache to be 2D for loading into SBUF
+# Flatten KV cache to be 3D for loading into SBUF
 new_cache_shape = (
+2,
 num_blocks * k_h * block_size_tiling_factor,
 tiled_block_size * d,
 )
-key_cache = key_cache.reshape(new_cache_shape)
-value_cache = value_cache.reshape(new_cache_shape)
+kv_cache = kv_cache.reshape(new_cache_shape)

 # Global Flash Attention accumulators
 o_buffer = nl.zeros(
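For illustration only (not part of this commit): because dim 0 is kept, flattening the fused cache flattens the key and value planes independently, matching what the two separate reshapes did before. A minimal sketch with made-up sizes:

import torch

num_blocks, k_h, block_size, d = 4, 2, 8, 16
block_size_tiling_factor = 2
tiled_block_size = block_size // block_size_tiling_factor

kv_cache = torch.randn(2, num_blocks, k_h, block_size, d)
flat = kv_cache.reshape(2, num_blocks * k_h * block_size_tiling_factor,
                        tiled_block_size * d)

# Each plane of the flattened cache equals flattening that plane on its own.
assert torch.equal(flat[0], kv_cache[0].reshape(-1, tiled_block_size * d))
assert torch.equal(flat[1], kv_cache[1].reshape(-1, tiled_block_size * d))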
@@ -621,8 +616,7 @@ def flash_paged_attention(
 load_kv_tile_from_cache(
 cur_k_tile=cur_k_tile,
 cur_v_tile=cur_v_tile,
-key_cache=key_cache,
-value_cache=value_cache,
+kv_cache=kv_cache,
 block_tables=block_tables_sbuf,
 large_k_tile_idx=large_k_tile_idx,
 num_blocks_per_large_tile=num_blocks_per_large_tile,
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
 query,
 key,
 value,
-key_cache,
-value_cache,
+kv_cache,
 block_table,
 attn_mask,
 n_kv_head=None,
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
 - query: (1, n_heads, d, seq_q)
 - key: (1, n_kv_heads, d, seq_k)
 - value: (1, n_kv_heads, seq_v, d)
-- key_cache: (n_blocks, n_kv_heads, block_size, d)
-- value_cache: (n_blocks, n_kv_heads, block_size, d)
+- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
 - block_tables: (n_active_blocks, )
 - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)

@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
 for better DMA throughput
 """
 if n_kv_head is None:
-n_kv_head = key_cache.shape[1]
-assert key_cache.shape[1] == n_kv_head
+n_kv_head = kv_cache.shape[2]
+assert kv_cache.shape[0] == 2
+assert kv_cache.shape[2] == n_kv_head
 if head_size is None:
-head_size = key_cache.shape[-1]
+head_size = kv_cache.shape[-1]

 kwargs = dict(
 query=query,
 key=key,
 value=value,
-key_cache=key_cache,
-value_cache=value_cache,
+kv_cache=kv_cache,
 block_tables=block_table,
 mask=attn_mask,
 softmax_scale=1.0 / (head_size**0.5),
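For illustration only (not part of this commit): in the fused layout the head count moves from key_cache.shape[1] to kv_cache.shape[2], while head_size remains the last dimension. A minimal sketch with made-up sizes:

import torch

kv_cache = torch.empty(2, 16, 4, 32, 128)  # (2, n_blocks, n_kv_heads, block_size, d)
assert kv_cache.shape[0] == 2
n_kv_head = kv_cache.shape[2]    # 4
head_size = kv_cache.shape[-1]   # 128
softmax_scale = 1.0 / (head_size ** 0.5)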
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
 def reshape_and_cache(
 key: torch.Tensor,
 value: torch.Tensor,
-key_cache: torch.Tensor,
-value_cache: torch.Tensor,
+kv_cache: torch.Tensor,
 slot_mapping: torch.Tensor,
 ) -> None:
 """
@@ -886,29 +877,29 @@ def reshape_and_cache(
 (num_tokens, n_kv_head, d_head)
 value (torch.Tensor): Value tensor with shape
 (num_tokens, n_kv_head, d_head)
-key_cache (torch.Tensor): Key cache tensor with shape
-(num_blocks, n_kv_head, block_size, d_head)
-value_cache (torch.Tensor): Value cache tensor with shape
-(num_blocks, n_kv_head, block_size, d_head)
+kv_cache (torch.Tensor): Key/value cache tensor with shape
+(2, num_blocks, n_kv_head, block_size, d_head)
 slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
 with shape (num_tokens)

 Returns:
-None: Updates the key_cache and value_cache tensors in-place
+None: Updates the kv_cache tensor in-place
 """
-block_size = key_cache.size(2)
+block_size = kv_cache.size(3)
+n_kv_head = key.size(1)

 # Calculate indices with explicit floor division
 block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
 block_offsets = slot_mapping % block_size

-# Update caches using index_put_
-key_cache.index_put_(
-(block_indices.unsqueeze(1),
-torch.arange(key_cache.size(1),
-device=key.device), block_offsets.unsqueeze(1)), key)
+# Create the head indices tensor
+head_indices = torch.arange(n_kv_head, device=key.device)

-value_cache.index_put_(
-(block_indices.unsqueeze(1),
-torch.arange(value_cache.size(1),
-device=value.device), block_offsets.unsqueeze(1)), value)
+# Update caches using index_put_
+kv_cache.index_put_(
+(torch.tensor([0], device=key.device), block_indices[:, None],
+head_indices[None, :], block_offsets[:, None]), key)
+
+kv_cache.index_put_(
+(torch.tensor([1], device=key.device), block_indices[:, None],
+head_indices[None, :], block_offsets[:, None]), value)
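For illustration only (not part of this commit): the two index_put_ calls scatter each token's K and V rows into plane 0 and plane 1 of the fused cache at the block/offset derived from slot_mapping. A minimal sketch with made-up sizes, checked against a per-token loop:

import torch

num_blocks, n_kv_head, block_size, d_head = 4, 2, 8, 16
num_tokens = 5

key = torch.randn(num_tokens, n_kv_head, d_head)
value = torch.randn(num_tokens, n_kv_head, d_head)
slot_mapping = torch.tensor([3, 10, 17, 21, 30])  # distinct flat slot ids

kv_cache = torch.zeros(2, num_blocks, n_kv_head, block_size, d_head)
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
head_indices = torch.arange(n_kv_head)

# Scatter K into plane 0 and V into plane 1 of the fused cache.
kv_cache.index_put_((torch.tensor([0]), block_indices[:, None],
                     head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_((torch.tensor([1]), block_indices[:, None],
                     head_indices[None, :], block_offsets[:, None]), value)

# Reference: write each token's K/V rows directly.
expected = torch.zeros_like(kv_cache)
for t in range(num_tokens):
    b, o = block_indices[t], block_offsets[t]
    expected[0, b, :, o] = key[t]
    expected[1, b, :, o] = value[t]
assert torch.equal(kv_cache, expected)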