from typing import List, Optional

from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):
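    """OPT attention layer backed by CacheFlow's block-based KV cache.

    Prompt tokens are processed with FlashAttention over a packed
    variable-length batch, while generation tokens attend to previously
    cached keys/values through the custom CacheFlow attention kernel.
    """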

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
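        """Computes attention over the prompt tokens with FlashAttention.

        All prompts are packed into one variable-length batch: `prompt_lens`
        is converted into cumulative sequence lengths (`cu_seqlens`), and the
        result is copied into the pre-allocated `output` tensor.
        """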
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[2]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        device = query.device
        prefix_sum = [0]
        for prompt_len in prompt_lens:
            prefix_sum.append(prefix_sum[-1] + prompt_len)
        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
        max_prompt_len = max(prompt_lens)

        # FIXME(woosuk): Unnecessary copy. Optimize this.
        qkv = torch.stack([query, key, value], dim=1)
        out = self.flash_attn(
            qkv,
            cu_seqlens=prefix_sum,
            max_s=max_prompt_len,
            causal=True,
        )[0]
        # FIXME(woosuk): Unnecessary copy. Optimize this.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
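        """Computes attention for generation tokens against the KV cache.

        Each query token attends to its previously cached keys/values, which
        are located via `input_metadata.block_tables` and `context_lens`, and
        the result is written into the pre-allocated `output` tensor.
        """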
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
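        """Computes attention for a batch of prompt and generation tokens.

        The first `num_prompt_tokens` rows are handled by FlashAttention, the
        new keys/values are written into the block-based cache, and the
        remaining (generation) tokens are handled by the cached-KV kernel.
        The returned tensor may still contain padding rows.
        """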
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.prompt_lens,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:],
                query[num_prompt_tokens:],
                key_cache,
                value_cache,
                input_metadata,
            )

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)
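

# Example usage (an illustrative sketch, not executed anywhere in this module):
# the layer is driven by the model runner, which is assumed to prepare
# half-precision CUDA tensors for query/key/value, the block-based key/value
# caches with the layouts annotated above, and an InputMetadata object carrying
# the fields accessed in forward() (num_valid_tokens, num_prompt_tokens,
# prompt_lens, num_generation_tokens, block_tables, context_lens,
# max_context_len, slot_mapping).
#
#   attn = OPTCacheFlowAttention(scale=head_size ** -0.5)
#   output = attn(query, key, value,
#                 key_cache, value_cache,
#                 input_metadata,
#                 cache_event=None)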