Replace FlashAttention with xformers (#70)

parent 189ae23133
commit c9d5b6d4a8
@@ -3,11 +3,7 @@
 ## Installation

 ```bash
-pip install psutil numpy ray torch
-pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
-pip install sentencepiece # Required for LlamaTokenizer.
-pip install ninja # To parallelize the compilation of flash-attn.
-pip install flash-attn # This may take up to 10 mins.
+pip install ninja psutil numpy sentencepiece ray torch transformers xformers
 pip install -e .
 ```

@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-    # NOTE(woosuk): FlashAttention does not support float32.
+    # TODO(woosuk): Support FP32 for debugging.
     parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
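For orientation, the `--dtype` choices map onto torch dtypes roughly as sketched below. This is a hypothetical, illustrative helper rather than code touched by this commit; per the help text above, the 'default' choice resolves to FP16 for FP16/FP32 checkpoints.

```python
import torch

# Illustrative only; the helper name and mapping are assumptions, not the repo's code.
_DTYPE_CHOICES = {
    'half': torch.half,
    'bfloat16': torch.bfloat16,
}

def resolve_dtype(dtype_arg: str) -> torch.dtype:
    # Per the --dtype help text, 'default' uses FP16 precision.
    if dtype_arg == 'default':
        return torch.half
    return _DTYPE_CHOICES[dtype_arg]
```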
@@ -1,8 +1,8 @@
 from typing import Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
 import torch.nn as nn
+from xformers import ops as xops

 from cacheflow import attention_ops
 from cacheflow import cache_ops
@@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module):
     def __init__(self, scale: float) -> None:
         super().__init__()
         self.scale = float(scale)
+        self.attn_op = xops.fmha.cutlass.FwOp()

     def multi_query_kv_attention(
         self,
@@ -22,32 +23,21 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
         key: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,    # [num_prompt_tokens, num_heads, head_size]
-        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
-        max_prompt_len: int,
+        attn_bias: xops.AttentionBias,
     ) -> None:
-        if query.dtype == torch.float:
-            raise ValueError('The float data type is not supported by '
-                             'FlashAttention. Use the half data type instead.')
-        head_size = query.shape[-1]
-        if head_size > 128:
-            raise ValueError('FlashAttention does not support head_size > 128.')
-
-        # Directly call FlashAttention's internal function to avoid allocating
-        # a new tensor for the output.
-        _flash_attn_forward(
-            query,
-            key,
-            value,
-            output,
-            cumulative_prompt_lens,
-            cumulative_prompt_lens,
-            max_prompt_len,
-            max_prompt_len,
-            dropout_p=0.0,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax=False,
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
+        out = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=self.scale,
+            op=self.attn_op,
         )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.squeeze(0))
+        return output

     def single_query_cached_kv_attention(
         self,
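For context, here is a minimal, self-contained sketch (not part of the commit) of the xformers call pattern the rewritten `multi_query_kv_attention` relies on. The prompt lengths, head counts, and random tensors below are made-up illustration values, and a CUDA device is assumed.

```python
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Three prompts of different lengths packed along one token dimension.
prompt_lens = [5, 3, 7]
num_tokens = sum(prompt_lens)
num_heads, head_size = 12, 64
scale = head_size ** -0.5

query = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device='cuda')
key = torch.randn_like(query)
value = torch.randn_like(query)

# The block-diagonal causal mask lets each prompt attend only to its own
# earlier tokens, replacing FlashAttention's cu_seqlens/max_seqlen arguments.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)

# xformers expects a leading batch dimension, hence the unsqueeze(0)/squeeze(0).
out = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=scale,
    op=xops.fmha.cutlass.FwOp(),
)
out = out.squeeze(0)  # back to [num_tokens, num_heads, head_size]
```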
@@ -109,8 +99,7 @@ class GPTCacheFlowAttention(nn.Module):
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.cumulative_prompt_lens,
-                input_metadata.max_prompt_len,
+                input_metadata.attn_bias,
             )

         # Wait until the cache op is done.
@@ -143,13 +132,6 @@ class GPTCacheFlowAttention(nn.Module):
         return output.view(-1, num_heads * head_size)


-class OPTCacheFlowAttention(GPTCacheFlowAttention):
-    """OPT uses the same attention mechanism as GPT."""
-
-    def __init__(self, scale: float) -> None:
-        super().__init__(scale)
-
-
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     """Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
             input_metadata,
             cache_event,
         )
-
-
-class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
-    """LLaMA uses the GPT-NeoX style rotary embedding."""
@@ -1,6 +1,7 @@
 from typing import List, Dict, Tuple

 import torch
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow.sampling_params import SamplingParams

@@ -12,7 +13,6 @@ class InputMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
         prompt_lens: List[int],
-        cumulative_prompt_lens: torch.Tensor,
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
@@ -21,15 +21,14 @@ class InputMetadata:
         self.seq_groups = seq_groups
         self.seq_logprobs = seq_logprobs
         self.prompt_lens = prompt_lens
-        self.cumulative_prompt_lens = cumulative_prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.num_generation_tokens = context_lens.shape[0]
         self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
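A small sketch (not from the commit) of what this buys: the lengths-derived bias object is built once in `InputMetadata`, replacing the cumulative-length tensor and maximum prompt length that FlashAttention required. The `prompt_lens` values below are illustrative.

```python
import itertools
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [5, 3, 7]  # illustrative values

# Old-style bookkeeping needed by FlashAttention: cumulative lengths plus max length.
cumulative_prompt_lens = torch.tensor(
    [0] + list(itertools.accumulate(prompt_lens)), dtype=torch.int)
max_prompt_len = max(prompt_lens)

# New-style input for xformers: a single mask object derived from the same lengths.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
```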
@@ -41,15 +40,13 @@ class InputMetadata:

     def __repr__(self) -> str:
         return (f'InputMetadata('
-                f'num_prompts={self.num_prompts}, '
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'max_prompt_len={self.max_prompt_len}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'num_valid_tokens={self.num_valid_tokens}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'max_context_len={self.max_context_len}), '
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
-                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
-                f'slot_mapping={self.slot_mapping}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
                 f'context_lens={self.context_lens}, '
-                f'block_tables={self.block_tables})')
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')
@@ -7,7 +7,7 @@ from transformers import LlamaConfig

 from cacheflow.models import InputMetadata
 from cacheflow.models.activation import SiluAndMul
-from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.attention import GPTNeoXCacheFlowAttention
 from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
+        self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

     def forward(
         self,
@@ -202,8 +202,8 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
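To make the estimate above concrete, here is a hedged arithmetic sketch with made-up sizes (element counts, not bytes). The point of the updated comment is that no attention-map term, which would scale with the square of the token count, appears in the sum.

```python
# Illustrative numbers only; not taken from the commit.
max_num_batched_tokens = 2560
hidden_size = 4096
ffn_size = 4 * hidden_size
tensor_parallel_size = 1

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = max_num_batched_tokens * ffn_size // tensor_parallel_size
# No num_tokens ** 2 attention-map term is included, since memory-efficient
# attention never materializes the attention maps in GPU DRAM.
print(residual + qkv + ffn)  # element count; multiply by dtype size for bytes
```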
@@ -277,8 +277,8 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
         # estimating
         # 1) the maximum activation tensor size during inference
         # 2) the residual tensor size during inference
-        # Here, we assume that FlashAttention is used and
-        # thus the attention maps are never materialized in GPU DRAM.
+        # Here, we assume that we use memory-efficient attention which
+        # does not materialize the attention maps in GPU DRAM.
         residual = max_num_batched_tokens * self.hidden_size
         qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
         ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import OPTConfig

 from cacheflow.models import InputMetadata
-from cacheflow.models.attention import OPTCacheFlowAttention
+from cacheflow.models.attention import GPTCacheFlowAttention
 from cacheflow.models.sample import Sampler
 from cacheflow.models.utils import (hf_model_weights_iterator,
                                     load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+        self.attn = GPTCacheFlowAttention(scale=self.scaling)

     def forward(
         self,
@@ -136,11 +136,6 @@ class Worker:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        cumulative_prompt_lens: List[int] = [0]
-        for prompt_len in prompt_lens:
-            cumulative_prompt_lens.append(
-                cumulative_prompt_lens[-1] + prompt_len)
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
@@ -196,14 +191,11 @@ class Worker:
             for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
-        cumulative_prompt_lens_tensor = torch.tensor(
-            cumulative_prompt_lens, dtype=torch.int, device='cuda')

         input_metadata = InputMetadata(
             seq_groups=seq_groups,
             seq_logprobs=seq_logprobs,
             prompt_lens=prompt_lens,
-            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
@@ -23,7 +23,7 @@ def test_silu_and_mul(


 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
             for d in [512, 4096, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
@@ -1,8 +1,9 @@
 import random
 from typing import List, Optional

-from flash_attn.flash_attn_interface import _flash_attn_forward
 import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from cacheflow import attention_ops

@@ -81,8 +82,10 @@ def ref_multi_query_kv_attention(
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx

-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
+        # Create attention mask.
+        attn_mask = torch.triu(
+            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')

         ref_output = ref_masked_attention(
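As a side note on the mask change above: building the additive causal mask with `torch.finfo(dtype).min` instead of a hard-coded `-1e5` keeps the masking value representable in the reduced-precision dtypes (for example, `-1e5` overflows FP16's range). A minimal sketch of the same pattern, not taken from the repository:

```python
import torch

def causal_additive_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Future positions (strict upper triangle) get the most negative value
    # representable in `dtype`, so softmax gives them ~zero weight.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
    return mask * torch.finfo(dtype).min

mask = causal_additive_mask(4, torch.bfloat16)
```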
@@ -160,21 +163,20 @@ def test_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.randn(
+    qkv = torch.empty(
         num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)

     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(
+    key_cache = torch.empty(
         size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.randn(
+    value_cache = torch.empty(
         size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
-
-    # Adjust the range of the values to reduce precision errors.
-    query = query / (head_size ** 0.5)
-    key_cache = key_cache / (head_size ** 0.5)
-    value_cache = value_cache / (head_size ** 0.5)
+    value_cache.uniform_(-1e-3, 1e-3)
+
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
@@ -228,39 +230,30 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> None:
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
-    max_seq_len = max(seq_lens)
     num_tokens = sum(seq_lens)

+    scale = float(1.0 / (head_size ** 0.5))
+    qkv = torch.empty(
+        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv.uniform_(-1e-3, 1e-3)
+    query, key, value = qkv.unbind(dim=1)
+
+    attn_op = xops.fmha.cutlass.FwOp()
+    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+    output = xops.memory_efficient_attention_forward(
+        query.unsqueeze(0),
+        key.unsqueeze(0),
+        value.unsqueeze(0),
+        attn_bias=attn_bias,
+        p=0.0,
+        scale=scale,
+        op=attn_op,
+    )
+    output = output.squeeze(0)
+
     cu_seq_lens = [0]
     for seq_len in seq_lens:
         cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
-    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
-
-    scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
-    # Adjust the range of the values to reduce precision errors.
-    qkv = qkv / (head_size ** 0.5)
-
-    query, key, value = qkv.unbind(dim=1)
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
-    _flash_attn_forward(
-        query,
-        key,
-        value,
-        output,
-        cu_seq_lens,
-        cu_seq_lens,
-        max_seq_len,
-        max_seq_len,
-        dropout_p=0.0,
-        softmax_scale=scale,
-        causal=True,
-        return_softmax=False,
-    )
-
-    cu_seq_lens = cu_seq_lens.cpu().tolist()
     ref_output = ref_multi_query_kv_attention(
         cu_seq_lens,
         query,
@@ -277,8 +270,8 @@ def test_attention(seed: int) -> None:
     # the test fails due to the precision issue. Re-run the test if it fails.
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16, 32]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -292,14 +285,12 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )

-    # NOTE(woosuk): FlashAttention does not support FP32.
-    for dtype in [torch.half]:
-        # NOTE(woosuk): FlashAttention does not support head_size > 128.
-        for head_size in [64, 80, 96, 128]:
+    for dtype in [torch.half, torch.bfloat16]:
+        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             test_multi_query_kv_attention(
-                num_seqs=11,
+                num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
@@ -142,15 +142,16 @@ def test_gather_cached_kv(

 @torch.inference_mode()
 def test_cache() -> None:
-    test_copy_blocks(
-        num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-        block_size=8, num_blocks=1024, dtype=torch.half)
-    test_reshape_and_cache(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-        dtype=torch.half)
-    test_gather_cached_kv(
-        num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-        dtype=torch.half)
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        test_copy_blocks(
+            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
+            block_size=8, num_blocks=1024, dtype=dtype)
+        test_reshape_and_cache(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)
+        test_gather_cached_kv(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)


 if __name__ == '__main__':
@@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
-        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
+        weight = torch.empty(hidden_size)
+        weight.uniform_(-1e-3, 1e-3)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

@@ -41,7 +42,7 @@ def test_rms_norm(


 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
@@ -129,7 +129,7 @@ def test_rotary_embedding_neox(


 if __name__ == '__main__':
-    for dtype in [torch.half, torch.float]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
             test_rotary_embedding_neox(