Add support for GPT-NeoX (Pythia) (#50)

Woosuk Kwon 2023-04-28 00:32:10 -07:00 committed by GitHub
parent aa50b17ca7
commit a96d63c21d
9 changed files with 436 additions and 71 deletions

View File

@@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
super().__init__(scale)
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
"""Llama uses GPT-NeoX style rotary embedding."""
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""
def __init__(
self,
scale: float,
head_size: int,
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
@@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
# initializing the model. Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, head_size]
# Embedding size: [max_position, rotary_dim]
self.register_buffer('cos_sin_cache', cache, persistent=False)
def forward(
@@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
head_size = value_cache.shape[2]
pos_encoding_ops.rotary_embedding_neox(
positions,
query,
key,
head_size,
self.cos_sin_cache,
)
return super().forward(
@@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
input_metadata,
cache_event,
)
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
"""LLaMA uses the GPT-NeoX style rotary embedding."""

View File

@@ -0,0 +1,278 @@
"""1D GPT-NeoX model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from huggingface_hub import snapshot_download
from cacheflow.models import InputMetadata
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
scaling = self.head_size ** -0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output
class GPTNeoXMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
if config.hidden_act != 'gelu':
raise ValueError(f'Unsupported activation: {config.hidden_act}. '
'Only gelu is supported for now.')
self.act = torch.nn.GELU()
def forward(self, hidden_states):
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
position_ids=position_ids,
hidden_states=attn_input,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_input = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_input = self.post_attention_layernorm(attn_output)
mlp_output = self.mlp(mlp_input)
hidden_states = mlp_output + attn_output
return hidden_states
class GPTNeoXModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
position_ids: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
bias=False, gather_output=False,
perform_initialization=False)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.embed_out.weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
if "query_key_value" in name:
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
# [num_heads * 3 * head_size, num_heads * head_size], while the
# required shape is [3 * num_heads * head_size, num_heads * head_size].
# Thus, we need weight conversion.
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if 'query_key_value.weight' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
elif 'query_key_value.bias' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1).contiguous()
else:
assert False
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
path = os.path.join(path, f"{model_name}-np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(
path, "gpt_neox.embed_in.weight")
if os.path.exists(test_weight_path):
return path
folder = snapshot_download(model_name, allow_patterns="*.bin",
cache_dir=os.path.join(path, "cache"))
bin_files = glob.glob(os.path.join(folder, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-1e-3, 1e-3)
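The query_key_value branch of load_weights above is easiest to verify with a toy tensor. A self-contained check (hypothetical sizes) of the head-interleaved-to-QKV-grouped conversion:

```python
import torch

num_heads, head_size = 4, 8
hidden_size = num_heads * head_size

# HuggingFace GPT-NeoX stores the fused QKV weight head-interleaved:
# rows are [head0-q, head0-k, head0-v, head1-q, ...].
hf_weight = torch.randn(num_heads * 3 * head_size, hidden_size)

# Same reshape/transpose as load_weights: group all q rows, then k, then v.
converted = (hf_weight
             .view(num_heads, 3, head_size, hidden_size)
             .transpose(0, 1)
             .reshape(3 * hidden_size, hidden_size)
             .contiguous())

# Head 0's query rows stay first; its key rows move down by hidden_size,
# which is what q, k, v = qkv.chunk(chunks=3, dim=-1) expects.
assert torch.equal(converted[:head_size], hf_weight[:head_size])
assert torch.equal(converted[hidden_size:hidden_size + head_size],
                   hf_weight[head_size:2 * head_size])
```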

View File

@@ -289,4 +289,4 @@ class LlamaForCausalLM(nn.Module):
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)
param.data.uniform_(-1e-3, 1e-3)

View File

@@ -40,6 +40,37 @@ class CacheFlowMemoryAnalyzer:
max_num_blocks = swap_space // self.get_cache_block_size()
return max_num_blocks
def get_param_size(self) -> int:
raise NotImplementedError()
def get_max_act_size(self, max_num_batched_tokens: int) -> int:
raise NotImplementedError()
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
usable_memory = int(memory_utilization * self.gpu_memory)
param_size = self.get_param_size()
act_size = self.get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -69,7 +100,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
self.vocab_size = config.vocab_size
self.max_position = config.max_position_embeddings
def _get_param_size(self) -> int:
def get_param_size(self) -> int:
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
if self.embedding_size != self.hidden_size:
# Project in/out.
@@ -93,7 +124,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def _get_max_act_size(
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
@@ -114,31 +145,6 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
usable_memory = int(memory_utilization * self.gpu_memory)
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
@@ -167,9 +173,10 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
self.vocab_size = config.vocab_size
self.max_position = 8192
def _get_param_size(self) -> int:
def get_param_size(self) -> int:
# NOTE: LLaMA does not tie the two embeddings.
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
position_embedding = self.max_position * self.hidden_size
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
# NOTE: LLaMA does not have bias terms.
ln1 = self.hidden_size
@@ -188,11 +195,11 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
ffn = ln2 + gate + down + up
total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def _get_max_act_size(
def get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
@@ -213,28 +220,78 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
self.max_position = 8192
self.tie_word_embeddings = config.tie_word_embeddings
def get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
if self.tie_word_embeddings:
lm_head = 0
else:
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
# Rotary embedding.
# TODO(woosuk): Share the rotary embedding between layers.
rot = self.max_position * self.head_size
mha = ln1 + q + k + v + out + rot
ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
def get_max_act_size(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
gpu_memory = self.gpu_memory
usable_memory = int(memory_utilization * gpu_memory)
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
# NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
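As a back-of-the-envelope check on the base-class methods this commit hoists out of the per-model analyzers, here is the block-size and block-count arithmetic with made-up numbers (every value below is hypothetical, purely for illustration):

```python
# Hypothetical fp16 Pythia-like config on a 16 GiB GPU.
dtype_size = 2                  # bytes per element (fp16)
block_size = 8                  # tokens per KV-cache block
hidden_size = 2048
num_layers = 16
tensor_parallel_size = 1

# get_cache_block_size: a block holds keys and values for every layer.
key_cache_block = block_size * hidden_size // tensor_parallel_size
cache_block_size = dtype_size * num_layers * 2 * key_cache_block
# = 2 * 16 * 2 * (8 * 2048) = 1,048,576 bytes (1 MiB per block).

# get_max_num_gpu_blocks: the KV cache gets whatever is left after
# weights, peak activations, and workspace.
usable_memory = int(0.95 * 16 * 1024**3)
param_size = 2_800_000_000      # stand-in for get_param_size()
act_size = 400_000_000          # stand-in for get_max_act_size(...)
workspace_size = 1 * 1024**3    # stand-in for get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
max_num_blocks = max_cache_size // cache_block_size   # ~11,489 blocks
```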

View File

@@ -1,13 +1,14 @@
from typing import Union
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype
@@ -16,11 +17,15 @@ from cacheflow.models.utils import get_torch_dtype
_MODELS = {
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
'stablelm': GPTNeoXForCausalLM,
'pythia': GPTNeoXForCausalLM,
}
_MEMORY_ANALYZERS = {
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
'stablelm': GPTNeoXMemoryAnalyzer,
'pythia': GPTNeoXMemoryAnalyzer,
}
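This hunk only registers the new entries; how the keys are matched is not shown here. Presumably the loader scans for a registry key inside the HuggingFace model name, along these lines (a hypothetical sketch, not code from this commit):

```python
def get_model_class(model_name: str):
    # Hypothetical lookup: the first registry key that appears in the
    # model name wins, e.g. 'EleutherAI/pythia-1b' contains 'pythia'.
    for key, model_class in _MODELS.items():
        if key in model_name:
            return model_class
    raise ValueError(f'Unsupported model name: {model_name}')

model_class = get_model_class('EleutherAI/pythia-1b')  # GPTNeoXForCausalLM
```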

View File

@@ -327,4 +327,4 @@ class OPTForCausalLM(nn.Module):
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)
param.data.uniform_(-1e-3, 1e-3)

View File

@@ -4,6 +4,7 @@ void rotary_embedding_neox(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

View File

@@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int stride,
const int num_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = head_size / 2;
const int embed_dim = rot_dim / 2;
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int head_idx = i / embed_dim;
@@ -51,16 +52,17 @@ void rotary_embedding_neox(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size]
torch::Tensor& cos_sin_cache) // [max_position, head_size]
int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0);
int head_size = cos_sin_cache.size(1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
int stride = query.stride(0);
TORCH_CHECK(stride == key.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size / 2, 512));
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
query.scalar_type(),
@@ -71,6 +73,7 @@ void rotary_embedding_neox(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
stride,
num_heads,
head_size);
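A PyTorch reference for the updated kernel is handy when checking the new rot_dim/head_size split. This sketch mirrors the per-token loop above (shapes taken from the kernel's comments; the function name is hypothetical):

```python
import torch

def rotary_embedding_neox_ref(positions, query, key, head_size, cos_sin_cache):
    # positions: [num_tokens]; query/key: [num_tokens, num_heads * head_size];
    # cos_sin_cache: [max_position, rot_dim] with cos in the first half.
    num_tokens = query.shape[0]
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2
    num_heads = query.shape[1] // head_size
    cos = cos_sin_cache[positions, :embed_dim].unsqueeze(1)  # [num_tokens, 1, embed_dim]
    sin = cos_sin_cache[positions, embed_dim:].unsqueeze(1)
    outs = []
    for x in (query, key):
        x = x.view(num_tokens, num_heads, head_size).clone()
        x1 = x[..., :embed_dim].clone()
        x2 = x[..., embed_dim:rot_dim].clone()
        x[..., :embed_dim] = x1 * cos - x2 * sin
        x[..., embed_dim:rot_dim] = x2 * cos + x1 * sin
        outs.append(x.view(num_tokens, -1))  # channels >= rot_dim untouched
    return tuple(outs)
```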

View File

@@ -34,6 +34,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
base: int = 10000,
) -> None:
super().__init__()
self.rotary_dim = dim
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
@@ -52,13 +53,24 @@
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached)
sin = F.embedding(positions, self.sin_cached)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
query, key = apply_rotary_pos_emb(query, key, cos, sin)
query = query.transpose(0, 1).contiguous()
key = key.transpose(0, 1).contiguous()
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query_rot = query_rot.transpose(0, 1).contiguous()
key_rot = key_rot.transpose(0, 1).contiguous()
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
# Output query/key shape: [num_tokens, num_heads, head_size]
return query, key
@@ -69,6 +81,7 @@
num_heads: int,
head_size: int,
max_position: int,
rotary_dim: int,
dtype: torch.dtype,
base: int = 10000,
) -> None:
@@ -77,7 +90,7 @@
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
# Create the rotary embedding.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
@@ -92,12 +105,13 @@
positions,
out_query,
out_key,
head_size,
cos_sin_cache,
)
# Run the reference implementation.
ref_rotary_embedding = RefRotaryEmbeddingNeox(
dim=head_size,
dim=rotary_dim,
max_position_embeddings=max_position,
base=base,
).to(dtype=dtype, device='cuda')
@@ -123,5 +137,6 @@ if __name__ == '__main__':
num_heads=5,
head_size=head_size,
max_position=8192,
rotary_dim=int(head_size * 0.25),
dtype=dtype,
)