"""Multi-head attention."""
from typing import Optional
import torch
import torch.nn as nn
from xformers import ops as xops
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class GPTCacheFlowAttention(nn.Module):
"""GPT-style multi-head attention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the
generation tokens, and the paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->|
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
of 8, which is desirable for Tensor Cores.
The class does the following:
1. Perform multi_query_kv_attention for the prompts. This operation does
not use the KV cache.
2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
operations are issued by the cache engine before executing the forward
pass of the model, and they are executed asynchronously.
3. Reshape and store the input key and value tensors in the KV cache.
4. Perform single_query_cached_kv_attention for the generation tokens.
This operation reads the previous key and value tensors from the KV
cache.
5. Output a flattened 1D tensor.
"""
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
attn_bias: xops.AttentionBias,
) -> torch.Tensor:
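        # attn_bias is assumed to be an xformers attention bias (e.g. a
        # block-diagonal causal mask built by InputMetadata over the
        # concatenated prompts), which is what lets a single batched
        # memory_efficient_attention_forward call cover all prompts without
        # letting them attend to one another.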
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
return output
def single_query_cached_kv_attention(
self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
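        # NOTE: key_cache packs the head dimension as [head_size/x, x] so the
        # custom kernel can load keys in contiguous chunks of x elements
        # (presumably a vectorization factor tied to the cache dtype), while
        # value_cache keeps the flat [num_blocks, num_heads, head_size,
        # block_size] layout; see the reference sketch at the end of this file.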
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
)
def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.attn_bias,
)
# Wait until the cache op is done.
if cache_event is not None:
cache_event.wait()
# Reshape the keys and values and store them in the cache.
# When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
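            # slot_mapping is assumed to map each valid token to its flat slot
            # index in the paged cache (block_number * block_size +
            # block_offset), which is where its key and value vectors are
            # written.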
            # NOTE: The token-dimension stride is 3 * num_heads * head_size,
            # not num_heads * head_size, because the key and value are sliced
            # from the fused qkv tensor.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
value[:num_valid_tokens],
key_cache,
value_cache,
input_metadata.slot_mapping,
)
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
self.register_buffer("cos_sin_cache", cache, persistent=False)
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
)
return super().forward(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
cache_event,
)
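

# The reference below is only an illustration of the paged KV-cache layout and
# the block-table indirection used by single_query_cached_kv_attention; it is
# a slow pure-PyTorch sketch, not the custom kernel. It assumes block_tables
# is an integer tensor of shape [num_generation_tokens, max_blocks_per_seq]
# and context_lens is an integer tensor of shape [num_generation_tokens].
def _reference_single_query_cached_kv_attention(
    output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
    query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
    value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
) -> None:
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    for i in range(query.shape[0]):
        q = query[i].unsqueeze(0)  # [1, num_heads, head_size]
        context_len = int(context_lens[i])
        keys = []
        values = []
        for pos in range(context_len):
            # Translate the logical position into a (block, offset) pair.
            block_number = int(block_tables[i, pos // block_size])
            block_offset = pos % block_size
            # Undo the [head_size/x, x] packing of the key cache.
            k = key_cache[block_number, :, :, block_offset, :]
            keys.append(k.reshape(num_heads, head_size))
            values.append(value_cache[block_number, :, :, block_offset])
        keys = torch.stack(keys, dim=0)      # [context_len, num_heads, head_size]
        values = torch.stack(values, dim=0)  # [context_len, num_heads, head_size]
        # Scaled dot-product attention of the single query token over the
        # cached context, computed independently for each head.
        logits = torch.einsum('qhd,khd->hqk', q * scale, keys)
        probs = torch.softmax(logits, dim=-1)
        out = torch.einsum('hqk,khd->qhd', probs, values)
        output[i].copy_(out.squeeze(0))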