2024-05-15 11:52:45 +09:00
|
|
|
"""Attention layer with FlashAttention and PagedAttention.
|
|
|
|
|
|
|
|
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
the XFormers backend. The duplicated code will be removed once we use
flash-attn or flashinfer for all the attention operations.
|
|
|
|
"""
|
2024-03-24 21:39:33 -07:00
|
|
|
from dataclasses import dataclass
|
2024-05-08 12:07:05 -07:00
|
|
|
from typing import List, Optional, Tuple, Type
|
2024-03-24 21:39:33 -07:00
|
|
|
|
|
|
|
import torch
|
2024-05-15 11:52:45 +09:00
|
|
|
from vllm_flash_attn import flash_attn_varlen_func
|
2024-03-24 21:39:33 -07:00
|
|
|
|
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
2024-04-11 09:56:48 +09:00
|
|
|
AttentionMetadata,
|
|
|
|
AttentionMetadataPerStage)
|
2024-05-15 11:52:45 +09:00
|
|
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
|
|
|
PagedAttentionMetadata)
|
2024-03-24 21:39:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention implementation.

    Only the implementation class and the metadata type are specific to this
    backend. The KV-cache shape and the block swap/copy operations are all
    delegated unchanged to PagedAttention.
    """

    # ---- backend identity -------------------------------------------------

    @staticmethod
    def get_name() -> str:
        """Return this backend's name."""
        backend_name = "flash-attn"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        impl_cls = FlashAttentionImpl
        return impl_cls

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        """Construct a FlashAttentionMetadata, forwarding all arguments."""
        metadata = FlashAttentionMetadata(*args, **kwargs)
        return metadata

    # ---- KV-cache management (delegated to PagedAttention) ----------------

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape used by PagedAttention."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks as described by `src_to_dists` (in place)."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks from `src_kv_cache` into `dst_kv_cache`."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
2024-03-24 21:39:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for one stage (prefill or decode) of a FlashAttentionBackend
    batch.

    The fields below are read by FlashAttentionImpl.forward. Fields such as
    `block_tables` and `slot_mapping` presumably come from the parent
    metadata classes (not visible here — TODO confirm).

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in tensors, and those tensors have
    to be updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the number of computed tokens plus new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
|
|
|
|
|
|
|
|
|
|
|
|
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by FlashAttention and PagedAttention.

    Plain prompt attention (no cached context) runs through
    `flash_attn_varlen_func`. Prefix-enabled prefill and decoding run through
    the PagedAttention kernels.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Configure the attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head. Must be one of
                `PagedAttention.get_supported_head_sizes()`.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`.
                `num_heads` must be divisible by it.
            alibi_slopes: Optional per-head ALiBi slopes. Converted to a
                float32 tensor (created on torch's current default device).
            sliding_window: Optional sliding-window size. Stored as a
                `(left, right)` tuple in flash-attn `window_size` form, with
                `(-1, -1)` meaning no window.
            kv_cache_dtype: KV-cache dtype string forwarded to the cache-write
                and decode kernels.

        Raises:
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # Without an explicit KV head count every query head has its own KV.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # flash-attn takes a (left, right) window; only the left extent is
        # forwarded to the prefix-prefill kernel in `forward`.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        # Query heads sharing each KV head (grouped-query attention factor).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # KV-cache layout and kernels come from PagedAttention, so its head
        # size restrictions apply to this backend as well.
        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                               block_size * num_kv_heads * head_size]
                or None (the initial memory-profiling run). When None,
                nothing is written to the cache.
            attn_metadata: Metadata for attention. Prefill tokens come first,
                followed by decode tokens.
            kv_scale: Scale forwarded to the cache-write and decode kernels.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # NOTE: key_cache/value_cache are only bound in this branch. The
        # prefix-prefill and decode paths below assume a cache was provided.
        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_seq_len,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|