2023-04-02 00:30:17 -07:00
|
|
|
from typing import Optional
|
2023-02-23 09:31:55 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2023-05-05 02:01:08 -07:00
|
|
|
from xformers import ops as xops
|
2023-02-23 09:31:55 +00:00
|
|
|
|
2023-03-01 15:02:19 -08:00
|
|
|
from cacheflow import attention_ops
|
|
|
|
from cacheflow import cache_ops
|
2023-03-30 11:04:21 -07:00
|
|
|
from cacheflow import pos_encoding_ops
|
2023-05-09 15:30:12 -07:00
|
|
|
from cacheflow.model_executor.input_metadata import InputMetadata
|
2023-02-23 09:31:55 +00:00
|
|
|
|
|
|
|
|
2023-03-30 11:04:21 -07:00
|
|
|
class GPTCacheFlowAttention(nn.Module):
    """Multi-head attention backed by a block-based KV cache.

    Prompt tokens are attended with xformers'
    ``memory_efficient_attention_forward``. Generation tokens are attended
    with the custom ``attention_ops.single_query_cached_kv_attention`` kernel,
    which reads keys/values back from ``key_cache``/``value_cache`` through
    ``input_metadata.block_tables``. Newly computed keys/values are written
    into the caches with ``cache_ops.reshape_and_cache``.
    """

    # Head sizes accepted by attention_ops.single_query_cached_kv_attention.
    # Kept at class level so it is not rebuilt on every decoding step.
    _SUPPORTED_HEAD_SIZES = (32, 64, 80, 96, 128, 160, 192, 256)

    def __init__(self, scale: float) -> None:
        """
        Args:
            scale: Scaling factor forwarded as ``scale`` to both attention
                kernels (stored as a Python float).
        """
        super().__init__()
        self.scale = float(scale)
        # xformers CUTLASS forward operator, used for the prompt path only.
        self.attn_op = xops.fmha.cutlass.FwOp()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
        """Compute attention for prompt tokens with xformers.

        All prompt tokens are packed into one sequence (a batch of size 1 is
        created with ``unsqueeze(0)``). Separation between different prompts
        is therefore entirely up to ``attn_bias``, which is built by the
        caller. It is presumably a block-diagonal causal mask; confirm against
        InputMetadata.

        The result is written into ``output`` in place, and ``output`` is
        also returned.
        """
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
        """Compute attention for generation tokens against the KV cache.

        Results are written into ``output`` by the custom kernel.

        Raises:
            ValueError: if the head size (``value_cache.shape[2]``) is not one
                the kernel supports.
        """
        head_size = value_cache.shape[2]
        if head_size not in self._SUPPORTED_HEAD_SIZES:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{list(self._SUPPORTED_HEAD_SIZES)}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,  # [num_tokens, num_heads * head_size]
        key: torch.Tensor,  # [num_tokens, num_heads * head_size]
        value: torch.Tensor,  # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
        """Run attention for a mixed batch of prompt and generation tokens.

        Token layout, as implied by the slicing below:
        ``[0, num_prompt_tokens)`` are prompt tokens,
        ``[num_prompt_tokens, num_valid_tokens)`` are generation tokens, and
        anything after ``num_valid_tokens`` is padding. Padding rows of the
        returned tensor are left uninitialized.
        """
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done before touching the caches.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)
|
2023-03-30 11:04:21 -07:00
|
|
|
|
|
|
|
|
2023-04-28 00:32:10 -07:00
|
|
|
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        """
        Args:
            scale: Scaling factor passed through to the base attention class.
            rotary_dim: Number of channels per head that receive rotary
                embedding; ``rotary_dim // 2`` frequencies are generated.
            max_position: Number of positions precomputed in the cos/sin cache.
            base: Base of the geometric frequency progression.
        """
        super().__init__(scale)

        # Create the cos and sin cache.
        # Compute the frequencies explicitly in float32. Dividing an integer
        # arange by an int promotes to torch.get_default_dtype(), which the
        # model may have set to a 16-bit type (see FIXME below). That would
        # compute inv_freq in half precision before it is ever upcast.
        inv_freq = 1.0 / (base**(
            torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position, dtype=torch.float)
        freqs = torch.einsum('i,j -> ij', t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        # cos and sin halves concatenated along the last dim.
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer('cos_sin_cache', cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,  # [num_tokens]
        query: torch.Tensor,  # [num_tokens, num_heads * head_size]
        key: torch.Tensor,  # [num_tokens, num_heads * head_size]
        value: torch.Tensor,  # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
        """Apply NeoX rotary embedding to query/key, then run base attention."""
        # Apply rotary embedding to the query and key before passing them
        # to the attention op. The op's return value is unused, so it is
        # expected to rotate query and key in place (confirm in the kernel).
        head_size = value_cache.shape[2]
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
|