diff --git a/README.md b/README.md index 8df3fece..0543b9de 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,7 @@ ## Installation ```bash -pip install psutil numpy ray torch -pip install git+https://github.com/huggingface/transformers # Required for LLaMA. -pip install sentencepiece # Required for LlamaTokenizer. -pip install ninja # To parallelize the compilation of flash-attn. -pip install flash-attn # This may take up to 10 mins. +pip install ninja psutil numpy sentencepiece ray torch transformers xformers pip install -e . ``` diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 6f28b96f..b0296513 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--use-np-cache', action='store_true', help='save a numpy copy of model weights for faster loading') parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights') - # NOTE(woosuk): FlashAttention does not support float32. + # TODO(woosuk): Support FP32 for debugging. parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'], help=('data type for model weights and activations. ' 'The "default" option will use FP16 precision ' diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 4c085644..179fbd0b 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,8 +1,8 @@ from typing import Optional -from flash_attn.flash_attn_interface import _flash_attn_forward import torch import torch.nn as nn +from xformers import ops as xops from cacheflow import attention_ops from cacheflow import cache_ops @@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module): def __init__(self, scale: float) -> None: super().__init__() self.scale = float(scale) + self.attn_op = xops.fmha.cutlass.FwOp() def multi_query_kv_attention( self, @@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module): query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1] - max_prompt_len: int, + attn_bias: xops.AttentionBias, ) -> None: - if query.dtype == torch.float: - raise ValueError('The float data type is not supported by ' - 'FlashAttention. Use the half data type instead.') - head_size = query.shape[-1] - if head_size > 128: - raise ValueError('FlashAttention does not support head_size > 128.') - - # Directly call FlashAttention's internal function to avoid allocating - # a new tensor for the output. - _flash_attn_forward( - query, - key, - value, - output, - cumulative_prompt_lens, - cumulative_prompt_lens, - max_prompt_len, - max_prompt_len, - dropout_p=0.0, - softmax_scale=self.scale, - causal=True, - return_softmax=False, + # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. + out = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + op=self.attn_op, ) + # TODO(woosuk): Unnecessary copy. Optimize. + output.copy_(out.squeeze(0)) + return output def single_query_cached_kv_attention( self, @@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module): query[:num_prompt_tokens], key[:num_prompt_tokens], value[:num_prompt_tokens], - input_metadata.cumulative_prompt_lens, - input_metadata.max_prompt_len, + input_metadata.attn_bias, ) # Wait until the cache op is done. @@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module): return output.view(-1, num_heads * head_size) -class OPTCacheFlowAttention(GPTCacheFlowAttention): - """OPT uses the same attention mechanism as GPT.""" - - def __init__(self, scale: float) -> None: - super().__init__(scale) - - class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): """Attention with GPT-NeoX style rotary embedding.""" @@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): input_metadata, cache_event, ) - - -class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention): - """LLaMA uses the GPT-NeoX style rotary embedding.""" diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index c61bfff2..943524c9 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -1,6 +1,7 @@ from typing import List, Dict, Tuple import torch +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from cacheflow.sampling_params import SamplingParams @@ -12,7 +13,6 @@ class InputMetadata: seq_groups: List[Tuple[List[int], SamplingParams]], seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. prompt_lens: List[int], - cumulative_prompt_lens: torch.Tensor, slot_mapping: torch.Tensor, context_lens: torch.Tensor, max_context_len: int, @@ -21,15 +21,14 @@ class InputMetadata: self.seq_groups = seq_groups self.seq_logprobs = seq_logprobs self.prompt_lens = prompt_lens - self.cumulative_prompt_lens = cumulative_prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens self.max_context_len = max_context_len self.block_tables = block_tables + self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) - self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: @@ -41,15 +40,13 @@ class InputMetadata: def __repr__(self) -> str: return (f'InputMetadata(' - f'num_prompts={self.num_prompts}, ' - f'num_prompt_tokens={self.num_prompt_tokens}, ' - f'max_prompt_len={self.max_prompt_len}, ' - f'num_generation_tokens={self.num_generation_tokens}, ' f'num_valid_tokens={self.num_valid_tokens}, ' - f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' - f'max_context_len={self.max_context_len}), ' + f'num_prompt_tokens={self.num_prompt_tokens}, ' + f'num_prompts={self.num_prompts}, ' f'prompt_lens={self.prompt_lens}, ' - f'cumulative_prompt_lens={self.cumulative_prompt_lens}, ' - f'slot_mapping={self.slot_mapping}, ' + f'num_generation_tokens={self.num_generation_tokens}, ' f'context_lens={self.context_lens}, ' - f'block_tables={self.block_tables})') + f'max_context_len={self.max_context_len}), ' + f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' + f'block_tables={self.block_tables}), ' + f'slot_mapping={self.slot_mapping}') diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 70665030..0669742d 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -7,7 +7,7 @@ from transformers import LlamaConfig from cacheflow.models import InputMetadata from cacheflow.models.activation import SiluAndMul -from cacheflow.models.attention import LlamaCacheFlowAttention +from cacheflow.models.attention import GPTNeoXCacheFlowAttention from cacheflow.models.layernorm import RMSNorm from cacheflow.models.sample import Sampler from cacheflow.models.utils import (hf_model_weights_iterator, @@ -79,7 +79,7 @@ class LlamaAttention(nn.Module): input_is_parallel=True, perform_initialization=False, ) - self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim) + self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim) def forward( self, diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 2f15052a..0adc2e79 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size @@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer): # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size @@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer): # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 79b81cd0..4f1e729c 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -6,7 +6,7 @@ from torch import nn from transformers import OPTConfig from cacheflow.models import InputMetadata -from cacheflow.models.attention import OPTCacheFlowAttention +from cacheflow.models.attention import GPTCacheFlowAttention from cacheflow.models.sample import Sampler from cacheflow.models.utils import (hf_model_weights_iterator, load_tensor_parallel_weights) @@ -55,7 +55,7 @@ class OPTAttention(nn.Module): self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) - self.attn = OPTCacheFlowAttention(scale=self.scaling) + self.attn = GPTCacheFlowAttention(scale=self.scaling) def forward( self, diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 9b76d04e..59001b9d 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -136,11 +136,6 @@ class Worker: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - cumulative_prompt_lens: List[int] = [0] - for prompt_len in prompt_lens: - cumulative_prompt_lens.append( - cumulative_prompt_lens[-1] + prompt_len) - # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 @@ -196,14 +191,11 @@ class Worker: for block_table in generation_block_tables] block_tables_tensor = torch.tensor( padded_block_tables, dtype=torch.int, device='cuda') - cumulative_prompt_lens_tensor = torch.tensor( - cumulative_prompt_lens, dtype=torch.int, device='cuda') input_metadata = InputMetadata( seq_groups=seq_groups, seq_logprobs=seq_logprobs, prompt_lens=prompt_lens, - cumulative_prompt_lens=cumulative_prompt_lens_tensor, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, max_context_len=max_context_len, diff --git a/tests/kernels/activation.py b/tests/kernels/activation.py index 3d9a9a64..b35bea61 100644 --- a/tests/kernels/activation.py +++ b/tests/kernels/activation.py @@ -23,7 +23,7 @@ def test_silu_and_mul( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for num_tokens in [7, 83, 2048]: for d in [512, 4096, 13824]: print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index 4567315d..ae46fd6b 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -1,8 +1,9 @@ import random from typing import List, Optional -from flash_attn.flash_attn_interface import _flash_attn_forward import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from cacheflow import attention_ops @@ -81,8 +82,10 @@ def ref_multi_query_kv_attention( end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx - # Create attention mask - attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 + # Create attention mask. + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype, device='cuda') ref_output = ref_masked_attention( @@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention( num_blocks: int, dtype: torch.dtype, ) -> None: - qkv = torch.randn( + qkv = torch.empty( num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + qkv.uniform_(-1e-3, 1e-3) query, _, _ = qkv.unbind(dim=1) + x = 16 // torch.tensor([], dtype=dtype).element_size() key_block_shape = (num_heads, head_size // x, block_size, x) - key_cache = torch.randn( + key_cache = torch.empty( size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') + key_cache.uniform_(-1e-3, 1e-3) value_block_shape = (num_heads, head_size, block_size) - value_cache = torch.randn( + value_cache = torch.empty( size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') - - # Adjust the range of the values to reduce precision errors. - query = query / (head_size ** 0.5) - key_cache = key_cache / (head_size ** 0.5) - value_cache = value_cache / (head_size ** 0.5) + value_cache.uniform_(-1e-3, 1e-3) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)] max_context_len = max(context_lens) @@ -228,39 +230,30 @@ def test_multi_query_kv_attention( dtype: torch.dtype, ) -> None: seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - max_seq_len = max(seq_lens) num_tokens = sum(seq_lens) + scale = float(1.0 / (head_size ** 0.5)) + qkv = torch.empty( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + qkv.uniform_(-1e-3, 1e-3) + query, key, value = qkv.unbind(dim=1) + + attn_op = xops.fmha.cutlass.FwOp() + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + output = output.squeeze(0) + cu_seq_lens = [0] for seq_len in seq_lens: cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda') - - scale = float(1.0 / (head_size ** 0.5)) - qkv = torch.randn( - num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - # Adjust the range of the values to reduce precision errors. - qkv = qkv / (head_size ** 0.5) - - query, key, value = qkv.unbind(dim=1) - output = torch.empty( - num_tokens, num_heads, head_size, dtype=dtype, device='cuda') - _flash_attn_forward( - query, - key, - value, - output, - cu_seq_lens, - cu_seq_lens, - max_seq_len, - max_seq_len, - dropout_p=0.0, - softmax_scale=scale, - causal=True, - return_softmax=False, - ) - - cu_seq_lens = cu_seq_lens.cpu().tolist() ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, @@ -277,8 +270,8 @@ def test_attention(seed: int) -> None: # the test fails due to the precision issue. Re-run the test if it fails. torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - for dtype in [torch.half, torch.float]: - for block_size in [8, 16, 32]: + for dtype in [torch.half, torch.bfloat16]: + for block_size in [8, 16, 32, 64]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' @@ -292,14 +285,12 @@ def test_attention(seed: int) -> None: dtype=dtype, ) - # NOTE(woosuk): FlashAttention does not support FP32. - for dtype in [torch.half]: - # NOTE(woosuk): FlashAttention does not support head_size > 128. - for head_size in [64, 80, 96, 128]: + for dtype in [torch.half, torch.bfloat16]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Testing multi_query_kv_attention with dtype={dtype}, ' f'head_size={head_size}') test_multi_query_kv_attention( - num_seqs=11, + num_seqs=5, num_heads=3, head_size=head_size, dtype=dtype, diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py index f444ac16..b750ca97 100644 --- a/tests/kernels/cache.py +++ b/tests/kernels/cache.py @@ -142,15 +142,16 @@ def test_gather_cached_kv( @torch.inference_mode() def test_cache() -> None: - test_copy_blocks( - num_mappings=23, num_layers=7, num_heads=17, head_size=16, - block_size=8, num_blocks=1024, dtype=torch.half) - test_reshape_and_cache( - num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, - dtype=torch.half) - test_gather_cached_kv( - num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, - dtype=torch.half) + for dtype in [torch.half, torch.bfloat16, torch.float]: + test_copy_blocks( + num_mappings=23, num_layers=7, num_heads=17, head_size=16, + block_size=8, num_blocks=1024, dtype=dtype) + test_reshape_and_cache( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=dtype) + test_gather_cached_kv( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=dtype) if __name__ == '__main__': diff --git a/tests/kernels/layernorm.py b/tests/kernels/layernorm.py index 0e0072d8..a61fa9b6 100644 --- a/tests/kernels/layernorm.py +++ b/tests/kernels/layernorm.py @@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() - weight = torch.randn(hidden_size) / (hidden_size ** 0.5) + weight = torch.empty(hidden_size) + weight.uniform_(-1e-3, 1e-3) self.weight = nn.Parameter(weight) self.variance_epsilon = eps @@ -41,7 +42,7 @@ def test_rms_norm( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for num_tokens in [7, 128, 2048]: for hidden_size in [13, 64, 1024, 5120]: print(f'Testing RMS kernel with dtype={dtype}, num_tokens=' diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py index 11fd6695..16b3992a 100644 --- a/tests/kernels/pos_encoding.py +++ b/tests/kernels/pos_encoding.py @@ -129,7 +129,7 @@ def test_rotary_embedding_neox( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Running tests for head_size={head_size} and dtype={dtype}') test_rotary_embedding_neox(