Replace FlashAttention with xformers (#70)

Woosuk Kwon 2023-05-05 02:01:08 -07:00 committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions

View File

@ -3,11 +3,7 @@
## Installation
```bash
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install ninja # To parallelize the compilation of flash-attn.
pip install flash-attn # This may take up to 10 mins.
pip install ninja psutil numpy sentencepiece ray torch transformers xformers
pip install -e .
```

View File

@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): FlashAttention does not support float32.
# TODO(woosuk): Support FP32 for debugging.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '

View File

@ -1,8 +1,8 @@
from typing import Optional
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn
from xformers import ops as xops
from cacheflow import attention_ops
from cacheflow import cache_ops
@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
def multi_query_kv_attention(
self,
@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
attn_bias: xops.AttentionBias,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
return_softmax=False,
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
return output
def single_query_cached_kv_attention(
self,
@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
input_metadata.attn_bias,
)
# Wait until the cache op is done.
@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
return output.view(-1, num_heads * head_size)
class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""
def __init__(self, scale: float) -> None:
super().__init__(scale)
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""
@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
input_metadata,
cache_event,
)
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
"""LLaMA uses the GPT-NeoX style rotary embedding."""

View File

@ -1,6 +1,7 @@
from typing import List, Dict, Tuple
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from cacheflow.sampling_params import SamplingParams
@ -12,7 +13,6 @@ class InputMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@ -21,15 +21,14 @@ class InputMetadata:
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@ -41,15 +40,13 @@ class InputMetadata:
def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}), '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
f'max_context_len={self.max_context_len}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}, '
f'slot_mapping={self.slot_mapping})')
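The precomputed `attn_bias` is what makes `cumulative_prompt_lens` and `max_prompt_len` unnecessary. A quick way to see what it encodes is to materialize it for a toy case (a sketch that assumes xformers' `AttentionBias.materialize` debugging helper; not part of this commit):

```python
# Sketch: inspect what BlockDiagonalCausalMask.from_seqlens encodes (toy sizes).
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

bias = BlockDiagonalCausalMask.from_seqlens([2, 3])
# Dense [5, 5] additive mask: 0 where attention is allowed, -inf above each
# prompt's diagonal and everywhere across prompt boundaries.
print(bias.materialize((5, 5), dtype=torch.float32))
```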

View File

@ -7,7 +7,7 @@ from transformers import LlamaConfig
from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)
def forward(
self,

View File

@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
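As a rough worked example of what these analyzers estimate (model sizes below are hypothetical, not taken from this diff), the tracked activation terms stay linear in the number of batched tokens once the attention map is assumed never to be materialized:

```python
# Back-of-the-envelope activation estimate (hypothetical sizes).
max_num_batched_tokens = 2560
hidden_size, ffn_size = 5120, 13824
tensor_parallel_size = 1
bytes_per_elem = 2                       # FP16/BF16

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = max_num_batched_tokens * ffn_size // tensor_parallel_size

print(f'residual: {residual * bytes_per_elem / 2**20:.0f} MiB')  # ~25 MiB
print(f'qkv:      {qkv * bytes_per_elem / 2**20:.0f} MiB')       # ~75 MiB
print(f'ffn:      {ffn * bytes_per_elem / 2**20:.0f} MiB')       # ~68 MiB
# A materialized [num_tokens, num_tokens] attention map would grow
# quadratically instead, which is what this assumption avoids counting.
```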

View File

@ -6,7 +6,7 @@ from torch import nn
from transformers import OPTConfig
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import GPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = GPTCacheFlowAttention(scale=self.scaling)
def forward(
self,

View File

@ -136,11 +136,6 @@ class Worker:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
cumulative_prompt_lens: List[int] = [0]
for prompt_len in prompt_lens:
cumulative_prompt_lens.append(
cumulative_prompt_lens[-1] + prompt_len)
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
@ -196,14 +191,11 @@ class Worker:
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
cumulative_prompt_lens_tensor = torch.tensor(
cumulative_prompt_lens, dtype=torch.int, device='cuda')
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,

View File

@ -23,7 +23,7 @@ def test_silu_and_mul(
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')

View File

@ -1,8 +1,9 @@
import random
from typing import List, Optional
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from cacheflow import attention_ops
@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
end_idx = cu_seq_lens[i + 1]
seq_len = end_idx - start_idx
# Create attention mask
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
# Create attention mask.
attn_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
ref_output = ref_masked_attention(
@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
num_blocks: int,
dtype: torch.dtype,
) -> None:
qkv = torch.randn(
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
key_cache = torch.empty(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
value_cache = torch.empty(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
# Adjust the range of the values to reduce precision errors.
query = query / (head_size ** 0.5)
key_cache = key_cache / (head_size ** 0.5)
value_cache = value_cache / (head_size ** 0.5)
value_cache.uniform_(-1e-3, 1e-3)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
max_context_len = max(context_lens)
@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
dtype: torch.dtype,
) -> None:
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
max_seq_len = max(seq_lens)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size ** 0.5))
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)
attn_op = xops.fmha.cutlass.FwOp()
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
output = output.squeeze(0)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
# Adjust the range of the values to reduce precision errors.
qkv = qkv / (head_size ** 0.5)
query, key, value = qkv.unbind(dim=1)
output = torch.empty(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
_flash_attn_forward(
query,
key,
value,
output,
cu_seq_lens,
cu_seq_lens,
max_seq_len,
max_seq_len,
dropout_p=0.0,
softmax_scale=scale,
causal=True,
return_softmax=False,
)
cu_seq_lens = cu_seq_lens.cpu().tolist()
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
query,
@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
# the test fails due to the precision issue. Re-run the test if it fails.
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
for dtype in [torch.half, torch.float]:
for block_size in [8, 16, 32]:
for dtype in [torch.half, torch.bfloat16]:
for block_size in [8, 16, 32, 64]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
dtype=dtype,
)
# NOTE(woosuk): FlashAttention does not support FP32.
for dtype in [torch.half]:
# NOTE(woosuk): FlashAttention does not support head_size > 128.
for head_size in [64, 80, 96, 128]:
for dtype in [torch.half, torch.bfloat16]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
test_multi_query_kv_attention(
num_seqs=11,
num_seqs=5,
num_heads=3,
head_size=head_size,
dtype=dtype,
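The body of `ref_multi_query_kv_attention` is only partially visible in this diff; for context, a per-prompt reference of the kind it computes could look like the sketch below (the helper name and einsum formulation are mine; only the mask construction mirrors the test above):

```python
# Sketch of a per-prompt causal-attention reference (hypothetical helper).
import torch

def ref_causal_attention(q, k, v, scale):
    # q, k, v: [seq_len, num_heads, head_size]
    seq_len = q.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device), diagonal=1)
    mask = mask * torch.finfo(torch.float32).min        # additive causal mask
    scores = torch.einsum('qhd,khd->hqk', q.float(), k.float()) * scale + mask
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('hqk,khd->qhd', probs, v.float()).to(q.dtype)
```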

View File

@ -142,15 +142,16 @@ def test_gather_cached_kv(
@torch.inference_mode()
def test_cache() -> None:
test_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=torch.half)
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=torch.half)
test_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=torch.half)
for dtype in [torch.half, torch.bfloat16, torch.float]:
test_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=dtype)
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
test_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
if __name__ == '__main__':

View File

@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
weight = torch.empty(hidden_size)
weight.uniform_(-1e-3, 1e-3)
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
@ -41,7 +42,7 @@ def test_rms_norm(
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 128, 2048]:
for hidden_size in [13, 64, 1024, 5120]:
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='

View File

@ -129,7 +129,7 @@ def test_rotary_embedding_neox(
if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
test_rotary_embedding_neox(