"""Multi-head attention."""
from typing import Optional

import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]


class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens
    always have length 1. The paddings are appended to make the input length a
    multiple of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
       not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
       operations are issued by the cache engine before executing the forward
       pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
       This operation reads the previous key and value tensors from the KV
       cache.
    5. Output a flattened 1D tensor.
    """
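
    # A minimal usage sketch (illustrative only; the concrete shapes and the
    # InputMetadata contents are assumptions, not part of this file):
    #
    #   attn = GPTCacheFlowAttention(num_heads=16, head_size=64,
    #                                scale=64 ** -0.5)
    #   out = attn(query, key, value, key_cache, value_cache,
    #              input_metadata, cache_event)
    #
    # where query/key/value are [num_tokens, num_heads * head_size] slices of
    # a fused qkv projection.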

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(
                f'head_size ({self.head_size}) is not supported by '
                'the single_query_cached_kv_attention kernel. '
                'Use one of the following head sizes: '
                f'{_SUPPORTED_HEAD_SIZES}.')

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
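        """Attention for the prompt tokens (the KV cache is not used).

        Computes attention over the flattened prompt tokens with xFormers'
        memory-efficient attention kernel and writes the result into
        `output`. The caller-provided `attn_bias` is expected to encode the
        causal mask and the prompt boundaries.
        """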
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
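        """Attention for the generation tokens, using the block-based KV cache.

        Invokes the custom single_query_cached_kv_attention kernel, which
        gathers each sequence's cached keys and values through
        `input_metadata.block_tables` and writes the attention output into
        `output` in place. (`x` in the key cache shape is the kernel's
        vectorized access width along the head dimension.)
        """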
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: Optional[torch.Tensor],      # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: Optional[torch.Tensor],    # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
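        """Runs prompt attention, caches the new keys and values, then runs
        cached attention for the generation tokens (steps 1-5 in the class
        docstring). Returns a flattened [num_tokens, num_heads * head_size]
        tensor that may include padding rows.
        """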
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens."
            )
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
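        """Precomputes the rotary embedding cache.

        cos_sin_cache[t] = [cos(t * inv_freq), sin(t * inv_freq)] with
        inv_freq[j] = base ** (-2j / rotary_dim), giving a
        [max_position, rotary_dim] buffer that the rotary kernel indexes by
        token position.
        """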
        super().__init__(num_heads, head_size, scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,    # [num_tokens]
        query: torch.Tensor,            # [num_tokens, num_heads * head_size]
        key: torch.Tensor,              # [num_tokens, num_heads * head_size]
        value: torch.Tensor,            # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                  # [num_tokens, num_heads * head_size]
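        """Applies GPT-NeoX rotary embedding to the query and key in place,
        then delegates to GPTCacheFlowAttention.forward.
        """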
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )