Replace FlashAttention with xformers (#70)

Woosuk Kwon 2023-05-05 02:01:08 -07:00 committed by GitHub
parent 189ae23133
commit c9d5b6d4a8
13 changed files with 89 additions and 133 deletions


@@ -3,11 +3,7 @@
 ## Installation
 ```bash
-pip install psutil numpy ray torch
-pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
-pip install sentencepiece # Required for LlamaTokenizer.
-pip install ninja # To parallelize the compilation of flash-attn.
-pip install flash-attn # This may take up to 10 mins.
+pip install ninja psutil numpy sentencepiece ray torch transformers xformers
 pip install -e .
 ```


@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): FlashAttention does not support float32.
+    # TODO(woosuk): Support FP32 for debugging.
     parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '


@@ -1,8 +1,8 @@
 from typing import Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops

 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()

     def multi_query_kv_attention(
         self,
@@ -22,32 +23,21 @@
         query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
-        )
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
+        return output

     def single_query_cached_kv_attention(
         self,
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.cumulative_prompt_lens,
-                input_metadata.max_prompt_len,
+                input_metadata.attn_bias,
             )

         # Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)


-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""
@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""


@@ -1,6 +1,7 @@
 from typing import List, Dict, Tuple

 import torch
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams
@@ -12,7 +13,6 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
-        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -21,15 +21,14 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
-        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
@@ -41,15 +40,13 @@ class InputMetadata:

     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'num_prompts={self.num_prompts}, '
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'max_prompt_len={self.max_prompt_len}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'max_context_len={self.max_context_len}), '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
-                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
-                f'slot_mapping={self.slot_mapping}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'context_lens={self.context_lens}, '
-                f'block_tables={self.block_tables})')
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')
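
Conceptually, the `attn_bias` built from `prompt_lens` is a block-diagonal causal mask over the packed prompt tokens. The xformers kernel never materializes it, but a dense equivalent (an illustrative sketch, not part of the diff) makes the semantics concrete:

```python
import torch

def dense_block_diagonal_causal(prompt_lens, dtype=torch.float32):
    # Dense [num_tokens, num_tokens] bias equivalent to
    # BlockDiagonalCausalMask.from_seqlens(prompt_lens):
    # 0 where attention is allowed, -inf everywhere else.
    total = sum(prompt_lens)
    bias = torch.full((total, total), float('-inf'), dtype=dtype)
    start = 0
    for n in prompt_lens:
        block = torch.triu(
            torch.full((n, n), float('-inf'), dtype=dtype), diagonal=1)
        bias[start:start + n, start:start + n] = block  # causal within a prompt
        start += n                                      # no cross-prompt attention
    return bias

print(dense_block_diagonal_causal([2, 3]))
```

Because the bias already encodes where each prompt starts and ends, the old `cumulative_prompt_lens` tensor and `max_prompt_len` field become redundant, which is why they are deleted here and in the worker below.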


@@ -7,7 +7,7 @@ from transformers import LlamaConfig
 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,


@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
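
For context, the activation estimate these analyzers compute is just a handful of per-token tensor sizes, with no attention-map term, since memory-efficient attention never materializes the [num_tokens, num_tokens] score matrix. A rough numeric sketch with hypothetical model dimensions (not taken from the diff):

```python
# Hypothetical sizes, for illustration only.
max_num_batched_tokens = 2560
hidden_size = 4096
ffn_size = 11008
tensor_parallel_size = 1

# Element counts mirroring the analyzer's formulas above.
residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * ffn_size) // tensor_parallel_size

# Multiply by the dtype size (2 bytes for FP16) to get bytes.
print(residual, qkv, ffn)  # 10485760 31457280 56360960
```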


@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)

     def forward(
         self,


@@ -136,11 +136,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        cumulative_prompt_lens: List[int] = [0]
-        for prompt_len in prompt_lens:
-            cumulative_prompt_lens.append(
-                cumulative_prompt_lens[-1] + prompt_len)
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -196,14 +191,11 @@
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
-        cumulative_prompt_lens_tensor = torch.tensor(
-            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
-            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,


@@ -23,7 +23,7 @@ def test_silu_and_mul(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
             for d in [512, 4096, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')


@@ -1,8 +1,9 @@
 import random
 from typing import List, Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow import attention_ops
@@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx

-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        # Create attention mask.
+        attn_mask = torch.triu(
+            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')

         ref_output = ref_masked_attention(
@@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(
+    key_cache = torch.empty(
         size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.randn(
+    value_cache = torch.empty(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-
-    # Adjust the range of the values to reduce precision errors.
-    query = query / (head_size ** 0.5)
-    key_cache = key_cache / (head_size ** 0.5)
-    value_cache = value_cache / (head_size ** 0.5)
+    value_cache.uniform_(-1e-3, 1e-3)

     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
@@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> None:
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    max_seq_len = max(seq_lens)
     num_tokens = sum(seq_lens)

+    scale = float(1.0 / (head_size ** 0.5))
+    qkv = torch.empty(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
+    query, key, value = qkv.unbind(dim=1)
+
+    attn_op = xops.fmha.cutlass.FwOp()
+    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+    output = xops.memory_efficient_attention_forward(
+        query.unsqueeze(0),
+        key.unsqueeze(0),
+        value.unsqueeze(0),
+        attn_bias=attn_bias,
+        p=0.0,
+        scale=scale,
+        op=attn_op,
+    )
+    output = output.squeeze(0)
+
     cu_seq_lens = [0]
     for seq_len in seq_lens:
         cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
-    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
-
-    scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    # Adjust the range of the values to reduce precision errors.
-    qkv = qkv / (head_size ** 0.5)
-    query, key, value = qkv.unbind(dim=1)
-
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    _flash_attn_forward(
-        query,
-        key,
-        value,
-        output,
-        cu_seq_lens,
-        cu_seq_lens,
-        max_seq_len,
-        max_seq_len,
-        dropout_p=0.0,
-        softmax_scale=scale,
-        causal=True,
-        return_softmax=False,
-    )
-
-    cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
         cu_seq_lens,
         query,
@@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
     # the test fails due to the precision issue. Re-run the test if it fails.
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16, 32]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )

-    # NOTE(woosuk): FlashAttention does not support FP32.
-    for dtype in [torch.half]:
-        # NOTE(woosuk): FlashAttention does not support head_size > 128.
-        for head_size in [64, 80, 96, 128]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             test_multi_query_kv_attention(
-                num_seqs=11,
+                num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,


@@ -142,15 +142,16 @@ def test_gather_cached_kv(

 @torch.inference_mode()
 def test_cache() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         test_copy_blocks(
             num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-            block_size=8, num_blocks=1024, dtype=torch.half)
+            block_size=8, num_blocks=1024, dtype=dtype)
         test_reshape_and_cache(
             num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=torch.half)
+            dtype=dtype)
         test_gather_cached_kv(
             num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=torch.half)
+            dtype=dtype)


 if __name__ == '__main__':


@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps
@@ -41,7 +42,7 @@ def test_rms_norm(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='


@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(

 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
             test_rotary_embedding_neox(