From d2b58ca203fcff18c66e93fc4a2f851090b8bf75 Mon Sep 17 00:00:00 2001
From: Liangfu Chen
Date: Thu, 3 Apr 2025 09:51:32 -0700
Subject: [PATCH] [Neuron][kernel] Fuse kv cache into a single tensor (#15911)

Signed-off-by: Liangfu Chen
---
 tests/neuron/1_core/test_cache.py          |  4 +-
 tests/neuron/1_core/test_prefix_prefill.py | 13 ++--
 vllm/attention/ops/nki_flash_attn.py       | 85 ++++++++++------
 3 files changed, 46 insertions(+), 56 deletions(-)

diff --git a/tests/neuron/1_core/test_cache.py b/tests/neuron/1_core/test_cache.py
index ea33727b..3d869cd2 100644
--- a/tests/neuron/1_core/test_cache.py
+++ b/tests/neuron/1_core/test_cache.py
@@ -64,9 +64,11 @@ def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
     key_cache = torch.zeros_like(key_cache_cpu, device=device)
     value_cache = torch.zeros_like(value_cache_cpu, device=device)
     slot_mapping = slot_mapping_cpu.to(device)
+    kv_cache = torch.stack([key_cache, value_cache])

     # Run vectorized implementation on XLA device
-    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+    reshape_and_cache(key, value, kv_cache, slot_mapping)
+    key_cache, value_cache = torch.unbind(kv_cache, dim=0)

     # Move results back to CPU for comparison
     key_cache_result = key_cache.cpu()
diff --git a/tests/neuron/1_core/test_prefix_prefill.py b/tests/neuron/1_core/test_prefix_prefill.py
index 5a811f6d..8f7e711b 100644
--- a/tests/neuron/1_core/test_prefix_prefill.py
+++ b/tests/neuron/1_core/test_prefix_prefill.py
@@ -258,13 +258,13 @@ def sample_inputs(
                     value[start_loc:end_loc])
         cur_ctx += block_size
         block_id += 1
+    kv_cache = torch.stack([k_cache, v_cache])

     return (
         query,
         k,
         v,
-        k_cache,
-        v_cache,
+        kv_cache,
         block_table,
         key,
         value,
@@ -361,8 +361,7 @@ def test_contexted_kv_attention(
         query,
         k_active,
         v_active,
-        k_cache,
-        v_cache,
+        kv_cache,
         block_table,
         key,
         value,
@@ -439,8 +438,7 @@ def test_contexted_kv_attention(
     query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
     k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
     v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
-    k_cache = k_cache.permute(0, 2, 1, 3).contiguous()
-    v_cache = v_cache.permute(0, 2, 1, 3).contiguous()
+    kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()

     # transform block table
     active_block_table = get_active_block_tables(
@@ -487,8 +485,7 @@ def test_contexted_kv_attention(
         query.to(device=device),
         k.to(device=device),
         v.to(device=device),
-        k_cache.to(device=device),
-        v_cache.to(device=device),
+        kv_cache.to(device=device),
         active_block_table.to(device=device),
         attn_mask.to(device=device),
     )
diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py
index dcf9b0ef..6bce5879 100644
--- a/vllm/attention/ops/nki_flash_attn.py
+++ b/vllm/attention/ops/nki_flash_attn.py
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
 def load_kv_tile_from_cache(
     cur_k_tile,
     cur_v_tile,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_tables,
     large_k_tile_idx,
     num_blocks_per_large_tile,
@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
     for load_idx in nl.affine_range(num_loads):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
-        loaded = nl.load(key_cache[block_tables[load_idx, i_p,
-                                                large_k_tile_idx], i_f])
+        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
+                                                  large_k_tile_idx], i_f])
         if cur_k_tile.dtype != loaded.dtype:
             loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
         # Transpose SBUF tensor using PE
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
     # load value cache
     for load_idx in nl.affine_range(num_loads):
-        loaded = nl.load(value_cache[block_tables[load_idx, i_p,
+        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                   large_k_tile_idx], i_f])
         if cur_v_tile.dtype != loaded.dtype:
             loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
@@ -418,8 +417,7 @@ def flash_paged_attention(
     query,
     key,
     value,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_tables,
     mask,
     softmax_scale=None,
@@ -434,8 +432,7 @@ def flash_paged_attention(
     - query: shape (1, n_heads, d, seq_q)
     - key: shape (1, n_kv_heads, d, seq_k)
     - value: shape (1, n_kv_heads, seq_v, d)
-    - key_cache: (num_blocks, n_kv_heads, block_size, d)
-    - value_cache: (num_blocks, n_kv_heads, block_size, d)
+    - kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
     - block_tables: (num_active_blocks, )
     - mask: (seq_q, num_active_blocks * block_size + seq_q)
     - o: shape (1, n_heads, seq_q, d)
@@ -444,7 +441,7 @@ def flash_paged_attention(
     - We use continuous batching by default, so the batch dimension is
       always 1, and different requests are concatenated along sequence
       dimension.
-    - We use paged cache blocks (key_cache, value_cache) to store KV cache.
+    - We use paged cache blocks (kv_cache) to store KV cache.

     IO tensor dtypes:
     - This kernel assumes all IO tensors have the same dtype except for
@@ -475,15 +472,13 @@ def flash_paged_attention(
     b, h, d, seqlen_q = query.shape
     B_D_SIZE = d
     n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
-    num_blocks, k_h, block_size, _ = key_cache.shape
+    _, num_blocks, k_h, block_size, _ = kv_cache.shape
     q_h_per_k_h = h // k_h
     assert b == 1, f"invalid batch size {b=}"
     assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
-    cache_shape = (num_blocks, k_h, block_size, d)
-    assert (tuple(key_cache.shape) == cache_shape
-            ), f"{key_cache.shape=} mismatch, expect {cache_shape}"
-    assert (tuple(value_cache.shape) == cache_shape
-            ), f"{value_cache.shape=} mismatch, expect {cache_shape}"
+    cache_shape = (2, num_blocks, k_h, block_size, d)
+    assert (tuple(kv_cache.shape) == cache_shape
+            ), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
     assert key is None or tuple(key.shape) == (
         1,
         k_h,
@@ -580,13 +575,13 @@ def flash_paged_attention(
         head_id=head_id,
     )

-    # Flatten KV cache to be 2D for loading into SBUF
+    # Flatten KV cache to be 3D for loading into SBUF
     new_cache_shape = (
+        2,
         num_blocks * k_h * block_size_tiling_factor,
         tiled_block_size * d,
     )
-    key_cache = key_cache.reshape(new_cache_shape)
-    value_cache = value_cache.reshape(new_cache_shape)
+    kv_cache = kv_cache.reshape(new_cache_shape)

     # Global Flash Attention accumulators
     o_buffer = nl.zeros(
@@ -621,8 +616,7 @@ def flash_paged_attention(
            load_kv_tile_from_cache(
                cur_k_tile=cur_k_tile,
                cur_v_tile=cur_v_tile,
-                key_cache=key_cache,
-                value_cache=value_cache,
+                kv_cache=kv_cache,
                block_tables=block_tables_sbuf,
                large_k_tile_idx=large_k_tile_idx,
                num_blocks_per_large_tile=num_blocks_per_large_tile,
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
     query,
     key,
     value,
-    key_cache,
-    value_cache,
+    kv_cache,
     block_table,
     attn_mask,
     n_kv_head=None,
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
     - query: (1, n_heads, d, seq_q)
     - key: (1, n_kv_heads, d, seq_k)
     - value: (1, n_kv_heads, seq_v, d)
-    - key_cache: (n_blocks, n_kv_heads, block_size, d)
-    - value_cache: (n_blocks, n_kv_heads, block_size, d)
+    - kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
     - block_tables: (n_active_blocks, )
     - attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
       for better DMA throughput
     """
     if n_kv_head is None:
-        n_kv_head = key_cache.shape[1]
-    assert key_cache.shape[1] == n_kv_head
+        n_kv_head = kv_cache.shape[2]
+    assert kv_cache.shape[0] == 2
+    assert kv_cache.shape[2] == n_kv_head
     if head_size is None:
-        head_size = key_cache.shape[-1]
+        head_size = kv_cache.shape[-1]

     kwargs = dict(
         query=query,
         key=key,
         value=value,
-        key_cache=key_cache,
-        value_cache=value_cache,
+        kv_cache=kv_cache,
         block_tables=block_table,
         mask=attn_mask,
         softmax_scale=1.0 / (head_size**0.5),
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
 ) -> None:
     """
@@ -886,29 +877,29 @@ def reshape_and_cache(
         (num_tokens, n_kv_head, d_head)
         value (torch.Tensor): Value tensor with shape
         (num_tokens, n_kv_head, d_head)
-        key_cache (torch.Tensor): Key cache tensor with shape
-        (num_blocks, n_kv_head, block_size, d_head)
-        value_cache (torch.Tensor): Value cache tensor with shape
-        (num_blocks, n_kv_head, block_size, d_head)
+        kv_cache (torch.Tensor): Key/value cache tensor with shape
+        (2, num_blocks, n_kv_head, block_size, d_head)
         slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
         with shape (num_tokens)

     Returns:
-        None: Updates the key_cache and value_cache tensors in-place
+        None: Updates the kv_cache tensor in-place
     """
-    block_size = key_cache.size(2)
+    block_size = kv_cache.size(3)
+    n_kv_head = key.size(1)

     # Calculate indices with explicit floor division
     block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
     block_offsets = slot_mapping % block_size

-    # Update caches using index_put_
-    key_cache.index_put_(
-        (block_indices.unsqueeze(1),
-         torch.arange(key_cache.size(1),
-                      device=key.device), block_offsets.unsqueeze(1)), key)
+    # Create the head indices tensor
+    head_indices = torch.arange(n_kv_head, device=key.device)

-    value_cache.index_put_(
-        (block_indices.unsqueeze(1),
-         torch.arange(value_cache.size(1),
-                      device=value.device), block_offsets.unsqueeze(1)), value)
+    # Update caches using index_put_
+    kv_cache.index_put_(
+        (torch.tensor([0], device=key.device), block_indices[:, None],
+         head_indices[None, :], block_offsets[:, None]), key)
+
+    kv_cache.index_put_(
+        (torch.tensor([1], device=key.device), block_indices[:, None],
+         head_indices[None, :], block_offsets[:, None]), value)
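
For reviewers who want to exercise the new calling convention on its own, the sketch below (separate from the patch itself) mirrors the updated test_reshape_and_cache test: it builds the fused (2, num_blocks, n_kv_head, block_size, d_head) cache with torch.stack and lets reshape_and_cache scatter keys and values into it in-place. The tensor sizes are illustrative assumptions, not values taken from the patch, and the import assumes a vLLM build with the Neuron dependencies installed.

    # Illustrative sketch of the fused kv_cache convention; sizes are assumptions.
    import torch

    from vllm.attention.ops.nki_flash_attn import reshape_and_cache

    num_tokens, n_kv_head, d_head = 8, 2, 16
    num_blocks, block_size = 4, 32

    key = torch.randn(num_tokens, n_kv_head, d_head)
    value = torch.randn(num_tokens, n_kv_head, d_head)

    # Fused layout introduced by this patch:
    # kv_cache[0] holds keys, kv_cache[1] holds values.
    key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache = torch.zeros_like(key_cache)
    kv_cache = torch.stack([key_cache, value_cache])

    # One flat slot per token: slot = block_idx * block_size + block_offset.
    slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Scatters key into kv_cache[0] and value into kv_cache[1] in-place.
    reshape_and_cache(key, value, kv_cache, slot_mapping)

    # Recover per-tensor views without copying, as the updated test does.
    key_cache, value_cache = torch.unbind(kv_cache, dim=0)

Because torch.unbind returns views of the stacked tensor, writes made through kv_cache stay visible through the unbound key_cache and value_cache, which is what the updated test relies on when comparing against the CPU reference.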