Add support for GPT-NeoX (Pythia) (#50)
This commit is contained in:
parent
aa50b17ca7
commit
a96d63c21d
@ -150,20 +150,20 @@ class OPTCacheFlowAttention(GPTCacheFlowAttention):
|
|||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
|
|
||||||
|
|
||||||
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||||
"""Llama uses GPT-NeoX style rotary embedding."""
|
"""Attention with GPT-NeoX style rotary embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scale: float,
|
scale: float,
|
||||||
head_size: int,
|
rotary_dim: int,
|
||||||
max_position: int = 8192,
|
max_position: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(scale)
|
super().__init__(scale)
|
||||||
|
|
||||||
# Create the cos and sin cache.
|
# Create the cos and sin cache.
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||||
t = torch.arange(max_position).float()
|
t = torch.arange(max_position).float()
|
||||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
@ -174,7 +174,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
|||||||
# initializing the model. Make it more robust.
|
# initializing the model. Make it more robust.
|
||||||
torch_dtype = torch.get_default_dtype()
|
torch_dtype = torch.get_default_dtype()
|
||||||
cache = cache.to(torch_dtype)
|
cache = cache.to(torch_dtype)
|
||||||
# Embedding size: [max_position, head_size]
|
# Embedding size: [max_position, rotary_dim]
|
||||||
self.register_buffer('cos_sin_cache', cache, persistent=False)
|
self.register_buffer('cos_sin_cache', cache, persistent=False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -190,10 +190,12 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
|||||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||||
# Apply rotary embedding to the query and key before passing them
|
# Apply rotary embedding to the query and key before passing them
|
||||||
# to the attention op.
|
# to the attention op.
|
||||||
|
head_size = value_cache.shape[2]
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding_neox(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
head_size,
|
||||||
self.cos_sin_cache,
|
self.cos_sin_cache,
|
||||||
)
|
)
|
||||||
return super().forward(
|
return super().forward(
|
||||||
@ -205,3 +207,7 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
|||||||
input_metadata,
|
input_metadata,
|
||||||
cache_event,
|
cache_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
|
||||||
|
"""LLaMA uses the GPT-NeoX style rotary embedding."""
|
||||||
|
278
cacheflow/models/gpt_neox.py
Normal file
278
cacheflow/models/gpt_neox.py
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
"""1D GPT-NeoX model compatible with HuggingFace weights."""
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import filelock
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from cacheflow.models import InputMetadata
|
||||||
|
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
|
||||||
|
from cacheflow.models.sample import Sampler
|
||||||
|
from cacheflow.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from cacheflow.sequence import SequenceOutputs
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.total_num_heads
|
||||||
|
|
||||||
|
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||||
|
|
||||||
|
self.query_key_value = ColumnParallelLinear(config.hidden_size,
|
||||||
|
3 * config.hidden_size,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
|
||||||
|
scaling = self.head_size ** -0.5
|
||||||
|
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||||
|
assert rotary_dim % 2 == 0
|
||||||
|
self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.LongTensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.query_key_value(hidden_states)
|
||||||
|
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(
|
||||||
|
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||||
|
output, _ = self.dense(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXMLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False)
|
||||||
|
if config.hidden_act != 'gelu':
|
||||||
|
raise ValueError(f'Unsupported activation: {config.hidden_act}. '
|
||||||
|
'Only gelu is supported for now.')
|
||||||
|
self.act = torch.nn.GELU()
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states, _ = self.dense_4h_to_h(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.use_parallel_residual = config.use_parallel_residual
|
||||||
|
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.attention = GPTNeoXAttention(config)
|
||||||
|
self.mlp = GPTNeoXMLP(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.LongTensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
attn_input = self.input_layernorm(hidden_states)
|
||||||
|
attn_output = self.attention(
|
||||||
|
position_ids=position_ids,
|
||||||
|
hidden_states=attn_input,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_parallel_residual:
|
||||||
|
# pseudocode:
|
||||||
|
# x = x + attn(ln1(x)) + mlp(ln2(x))
|
||||||
|
mlp_input = self.post_attention_layernorm(hidden_states)
|
||||||
|
mlp_output = self.mlp(mlp_input)
|
||||||
|
hidden_states = mlp_output + attn_output + hidden_states
|
||||||
|
else:
|
||||||
|
# pseudocode:
|
||||||
|
# x = x + attn(ln1(x))
|
||||||
|
# x = x + mlp(ln2(x))
|
||||||
|
attn_output = attn_output + hidden_states
|
||||||
|
mlp_input = self.post_attention_layernorm(attn_output)
|
||||||
|
mlp_output = self.mlp(mlp_input)
|
||||||
|
hidden_states = mlp_output + attn_output
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXModel(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
position_ids: torch.LongTensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_in(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.layers[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
position_ids,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class GPTNeoXForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.gpt_neox = GPTNeoXModel(config)
|
||||||
|
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
|
||||||
|
bias=False, gather_output=False,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
positions: torch.LongTensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> Dict[int, SequenceOutputs]:
|
||||||
|
hidden_states = self.gpt_neox(
|
||||||
|
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(
|
||||||
|
self.embed_out.weight, hidden_states, input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
|
||||||
|
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||||
|
|
||||||
|
def load_weights(self, weights_path: str):
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if "query_key_value" in name:
|
||||||
|
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
|
||||||
|
# [num_heads * 3 * head_size, num_heads * head_size], while the
|
||||||
|
# required shape is [3 * num_heads * head_size, num_heads * head_size].
|
||||||
|
# Thus, we need weight conversion.
|
||||||
|
loaded_weight = torch.from_numpy(
|
||||||
|
np.load(os.path.join(weights_path, name)))
|
||||||
|
shard_size = param.shape[0]
|
||||||
|
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
|
||||||
|
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||||
|
|
||||||
|
num_heads = self.config.num_attention_heads
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
head_size = hidden_size // num_heads
|
||||||
|
if 'query_key_value.weight' in name:
|
||||||
|
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous()
|
||||||
|
elif 'query_key_value.bias' in name:
|
||||||
|
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1).contiguous()
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
else:
|
||||||
|
loaded_weight = torch.from_numpy(
|
||||||
|
np.load(os.path.join(weights_path, name)))
|
||||||
|
for p in self._column_parallel_weights:
|
||||||
|
if p in name:
|
||||||
|
shard_size = param.shape[0]
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
shard_size * tensor_model_parallel_rank
|
||||||
|
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||||
|
break
|
||||||
|
for p in self._row_parallel_weights:
|
||||||
|
if p in name:
|
||||||
|
shard_size = param.shape[1]
|
||||||
|
loaded_weight = loaded_weight[
|
||||||
|
:,
|
||||||
|
shard_size * tensor_model_parallel_rank
|
||||||
|
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||||
|
break
|
||||||
|
|
||||||
|
assert param.shape == loaded_weight.shape
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_weights(model_name: str, path: str):
|
||||||
|
path = os.path.join(path, f"{model_name}-np")
|
||||||
|
path = os.path.abspath(os.path.expanduser(path))
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
lock_path = os.path.join(path, "file_lock")
|
||||||
|
lock = filelock.FileLock(lock_path)
|
||||||
|
|
||||||
|
with lock:
|
||||||
|
test_weight_path = os.path.join(
|
||||||
|
path, "gpt_neox.embed_in.weight")
|
||||||
|
if os.path.exists(test_weight_path):
|
||||||
|
return path
|
||||||
|
|
||||||
|
folder = snapshot_download(model_name, allow_patterns="*.bin",
|
||||||
|
cache_dir=os.path.join(path, "cache"))
|
||||||
|
bin_files = glob.glob(os.path.join(folder, "*.bin"))
|
||||||
|
|
||||||
|
for bin_file in tqdm(bin_files, desc="Convert format"):
|
||||||
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
|
for name, param in tqdm(state.items(), leave=False):
|
||||||
|
param_path = os.path.join(path, name)
|
||||||
|
with open(param_path, "wb") as f:
|
||||||
|
np.save(f, param.cpu().detach().numpy())
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
def initialize_dummy_weights(self) -> None:
|
||||||
|
for param in self.state_dict().values():
|
||||||
|
param.data.uniform_(-1e-3, 1e-3)
|
@ -289,4 +289,4 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def initialize_dummy_weights(self) -> None:
|
def initialize_dummy_weights(self) -> None:
|
||||||
for param in self.state_dict().values():
|
for param in self.state_dict().values():
|
||||||
param.data.uniform_(-0.1, 0.1)
|
param.data.uniform_(-1e-3, 1e-3)
|
||||||
|
@ -40,6 +40,37 @@ class CacheFlowMemoryAnalyzer:
|
|||||||
max_num_blocks = swap_space // self.get_cache_block_size()
|
max_num_blocks = swap_space // self.get_cache_block_size()
|
||||||
return max_num_blocks
|
return max_num_blocks
|
||||||
|
|
||||||
|
def get_param_size(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_max_act_size(self, max_num_batched_tokens: int) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_cache_block_size(self) -> int:
|
||||||
|
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
|
||||||
|
value_cache_block = key_cache_block
|
||||||
|
total = self.num_layers * (key_cache_block + value_cache_block)
|
||||||
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
|
return dtype_size * total
|
||||||
|
|
||||||
|
def get_max_num_gpu_blocks(
|
||||||
|
self,
|
||||||
|
max_num_batched_tokens: int,
|
||||||
|
memory_utilization: float = 0.95,
|
||||||
|
) -> int:
|
||||||
|
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
|
||||||
|
usable_memory = int(memory_utilization * self.gpu_memory)
|
||||||
|
|
||||||
|
param_size = self.get_param_size()
|
||||||
|
act_size = self.get_max_act_size(max_num_batched_tokens)
|
||||||
|
workspace_size = self.get_workspace_size()
|
||||||
|
|
||||||
|
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
|
||||||
|
if max_cache_size <= 0:
|
||||||
|
raise RuntimeError('Not enough GPU memory.')
|
||||||
|
max_num_blocks = max_cache_size // self.get_cache_block_size()
|
||||||
|
return max_num_blocks
|
||||||
|
|
||||||
|
|
||||||
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||||
|
|
||||||
@ -69,7 +100,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.max_position = config.max_position_embeddings
|
self.max_position = config.max_position_embeddings
|
||||||
|
|
||||||
def _get_param_size(self) -> int:
|
def get_param_size(self) -> int:
|
||||||
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
|
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
|
||||||
if self.embedding_size != self.hidden_size:
|
if self.embedding_size != self.hidden_size:
|
||||||
# Project in/out.
|
# Project in/out.
|
||||||
@ -93,7 +124,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
dtype_size = get_dtype_size(self.dtype)
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
def _get_max_act_size(
|
def get_max_act_size(
|
||||||
self,
|
self,
|
||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
) -> int:
|
) -> int:
|
||||||
@ -114,31 +145,6 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
dtype_size = get_dtype_size(self.dtype)
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
return dtype_size * max_act
|
return dtype_size * max_act
|
||||||
|
|
||||||
def get_cache_block_size(self) -> int:
|
|
||||||
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
|
|
||||||
value_cache_block = key_cache_block
|
|
||||||
total = self.num_layers * (key_cache_block + value_cache_block)
|
|
||||||
dtype_size = get_dtype_size(self.dtype)
|
|
||||||
return dtype_size * total
|
|
||||||
|
|
||||||
def get_max_num_gpu_blocks(
|
|
||||||
self,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
memory_utilization: float = 0.95,
|
|
||||||
) -> int:
|
|
||||||
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
|
|
||||||
usable_memory = int(memory_utilization * self.gpu_memory)
|
|
||||||
|
|
||||||
param_size = self._get_param_size()
|
|
||||||
act_size = self._get_max_act_size(max_num_batched_tokens)
|
|
||||||
workspace_size = self.get_workspace_size()
|
|
||||||
|
|
||||||
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
|
|
||||||
if max_cache_size <= 0:
|
|
||||||
raise RuntimeError('Not enough GPU memory.')
|
|
||||||
max_num_blocks = max_cache_size // self.get_cache_block_size()
|
|
||||||
return max_num_blocks
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||||
|
|
||||||
@ -167,9 +173,10 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.max_position = 8192
|
self.max_position = 8192
|
||||||
|
|
||||||
def _get_param_size(self) -> int:
|
def get_param_size(self) -> int:
|
||||||
|
# NOTE: LLaMA does not tie the two embeddings.
|
||||||
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||||
position_embedding = self.max_position * self.hidden_size
|
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||||
|
|
||||||
# NOTE: LLaMA does not have bias terms.
|
# NOTE: LLaMA does not have bias terms.
|
||||||
ln1 = self.hidden_size
|
ln1 = self.hidden_size
|
||||||
@ -188,11 +195,11 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
|
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
|
||||||
ffn = ln2 + gate + down + up
|
ffn = ln2 + gate + down + up
|
||||||
|
|
||||||
total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
|
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
|
||||||
dtype_size = get_dtype_size(self.dtype)
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
def _get_max_act_size(
|
def get_max_act_size(
|
||||||
self,
|
self,
|
||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
) -> int:
|
) -> int:
|
||||||
@ -213,28 +220,78 @@ class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
|||||||
dtype_size = get_dtype_size(self.dtype)
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
return dtype_size * max_act
|
return dtype_size * max_act
|
||||||
|
|
||||||
def get_cache_block_size(self) -> int:
|
|
||||||
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
|
class GPTNeoXMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||||
value_cache_block = key_cache_block
|
|
||||||
total = self.num_layers * (key_cache_block + value_cache_block)
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
block_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
gpu_memory: int,
|
||||||
|
cpu_memory: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
) -> None:
|
||||||
|
self.model_name = model_name
|
||||||
|
self.block_size = block_size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.gpu_memory = gpu_memory
|
||||||
|
self.cpu_memory = cpu_memory
|
||||||
|
self.tensor_parallel_size = tensor_parallel_size
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_size = config.hidden_size // self.num_heads
|
||||||
|
self.ffn_size = config.intermediate_size
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.max_position = 8192
|
||||||
|
self.tie_word_embeddings = config.tie_word_embeddings
|
||||||
|
|
||||||
|
def get_param_size(self) -> int:
|
||||||
|
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||||
|
if self.tie_word_embeddings:
|
||||||
|
lm_head = 0
|
||||||
|
else:
|
||||||
|
lm_head = self.vocab_size * self.hidden_size // self.tensor_parallel_size
|
||||||
|
|
||||||
|
ln1 = 2 * self.hidden_size
|
||||||
|
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||||
|
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||||
|
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||||
|
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||||
|
# Rotary embedding.
|
||||||
|
# TODO(woosuk): Share the rotary embedding between layers.
|
||||||
|
rot = self.max_position * self.head_size
|
||||||
|
mha = ln1 + q + k + v + out + rot
|
||||||
|
|
||||||
|
ln2 = 2 * self.hidden_size
|
||||||
|
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
|
||||||
|
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||||
|
ffn = ln2 + ffn1 + ffn2
|
||||||
|
|
||||||
|
total = word_embedding + self.num_layers * (mha + ffn) + lm_head
|
||||||
dtype_size = get_dtype_size(self.dtype)
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
def get_max_num_gpu_blocks(
|
def get_max_act_size(
|
||||||
self,
|
self,
|
||||||
max_num_batched_tokens: int,
|
max_num_batched_tokens: int,
|
||||||
memory_utilization: float = 0.95,
|
|
||||||
) -> int:
|
) -> int:
|
||||||
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
|
# NOTE: We approxmiately calculate the maximum activation size by
|
||||||
gpu_memory = self.gpu_memory
|
# estimating
|
||||||
usable_memory = int(memory_utilization * gpu_memory)
|
# 1) the maximum activation tensor size during inference
|
||||||
|
# 2) the residual tensor size during inference
|
||||||
param_size = self._get_param_size()
|
# Here, we assume that FlashAttention is used and
|
||||||
act_size = self._get_max_act_size(max_num_batched_tokens)
|
# thus the attention maps are never materialized in GPU DRAM.
|
||||||
workspace_size = self.get_workspace_size()
|
residual = max_num_batched_tokens * self.hidden_size
|
||||||
|
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
|
||||||
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
|
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
|
||||||
if max_cache_size <= 0:
|
# Double the activation size for input and output.
|
||||||
raise RuntimeError('Not enough GPU memory.')
|
max_act = 2 * (max(qkv, ffn) + residual)
|
||||||
max_num_blocks = max_cache_size // self.get_cache_block_size()
|
# Size of output logits.
|
||||||
return max_num_blocks
|
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
|
||||||
|
max_act = max(max_act, output_logits)
|
||||||
|
dtype_size = get_dtype_size(self.dtype)
|
||||||
|
return dtype_size * max_act
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||||
|
from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer
|
||||||
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
|
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
|
||||||
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
||||||
|
from cacheflow.models.gpt_neox import GPTNeoXForCausalLM
|
||||||
from cacheflow.models.llama import LlamaForCausalLM
|
from cacheflow.models.llama import LlamaForCausalLM
|
||||||
from cacheflow.models.opt import OPTForCausalLM
|
from cacheflow.models.opt import OPTForCausalLM
|
||||||
from cacheflow.models.utils import get_torch_dtype
|
from cacheflow.models.utils import get_torch_dtype
|
||||||
@ -16,11 +17,15 @@ from cacheflow.models.utils import get_torch_dtype
|
|||||||
_MODELS = {
|
_MODELS = {
|
||||||
'llama': LlamaForCausalLM,
|
'llama': LlamaForCausalLM,
|
||||||
'opt': OPTForCausalLM,
|
'opt': OPTForCausalLM,
|
||||||
|
'stablelm': GPTNeoXForCausalLM,
|
||||||
|
'pythia': GPTNeoXForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
_MEMORY_ANALYZERS = {
|
_MEMORY_ANALYZERS = {
|
||||||
'llama': LlamaMemoryAnalyzer,
|
'llama': LlamaMemoryAnalyzer,
|
||||||
'opt': OPTMemoryAnalyzer,
|
'opt': OPTMemoryAnalyzer,
|
||||||
|
'stablelm': GPTNeoXMemoryAnalyzer,
|
||||||
|
'pythia': GPTNeoXMemoryAnalyzer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -327,4 +327,4 @@ class OPTForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def initialize_dummy_weights(self) -> None:
|
def initialize_dummy_weights(self) -> None:
|
||||||
for param in self.state_dict().values():
|
for param in self.state_dict().values():
|
||||||
param.data.uniform_(-0.1, 0.1)
|
param.data.uniform_(-1e-3, 1e-3)
|
||||||
|
@ -4,6 +4,7 @@ void rotary_embedding_neox(
|
|||||||
torch::Tensor& positions,
|
torch::Tensor& positions,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key,
|
torch::Tensor& key,
|
||||||
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache);
|
torch::Tensor& cos_sin_cache);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
|
|||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [num_tokens]
|
||||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
|
const int rot_dim,
|
||||||
const int stride,
|
const int stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size) {
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
const int embed_dim = head_size / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
const int n = num_heads * embed_dim;
|
const int n = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
@ -51,16 +52,17 @@ void rotary_embedding_neox(
|
|||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& cos_sin_cache) // [max_position, head_size]
|
int head_size,
|
||||||
|
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||||
{
|
{
|
||||||
int num_tokens = query.size(0);
|
int num_tokens = query.size(0);
|
||||||
int head_size = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(1) / head_size;
|
||||||
int stride = query.stride(0);
|
int stride = query.stride(0);
|
||||||
TORCH_CHECK(stride == key.stride(0));
|
TORCH_CHECK(stride == key.stride(0));
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
@ -71,6 +73,7 @@ void rotary_embedding_neox(
|
|||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rot_dim,
|
||||||
stride,
|
stride,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_size);
|
head_size);
|
||||||
|
@ -34,6 +34,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.rotary_dim = dim
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
# Create cos and sin embeddings.
|
# Create cos and sin embeddings.
|
||||||
@ -52,13 +53,24 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
query_rot = query[..., : self.rotary_dim]
|
||||||
|
query_pass = query[..., self.rotary_dim :]
|
||||||
|
key_rot = key[..., : self.rotary_dim]
|
||||||
|
key_pass = key[..., self.rotary_dim :]
|
||||||
|
|
||||||
|
|
||||||
|
query_rot = query_rot.transpose(0, 1)
|
||||||
|
key_rot = key_rot.transpose(0, 1)
|
||||||
cos = F.embedding(positions, self.cos_cached)
|
cos = F.embedding(positions, self.cos_cached)
|
||||||
sin = F.embedding(positions, self.sin_cached)
|
sin = F.embedding(positions, self.sin_cached)
|
||||||
query = query.transpose(0, 1)
|
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||||
key = key.transpose(0, 1)
|
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||||
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||||
query = query.transpose(0, 1).contiguous()
|
|
||||||
key = key.transpose(0, 1).contiguous()
|
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
|
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
# Output query/key shape: [num_tokens, num_tokens, head_size]
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@ -69,6 +81,7 @@ def test_rotary_embedding_neox(
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
max_position: int,
|
max_position: int,
|
||||||
|
rotary_dim: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -77,7 +90,7 @@ def test_rotary_embedding_neox(
|
|||||||
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||||
|
|
||||||
# Create the rotary embedding.
|
# Create the rotary embedding.
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
|
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||||
t = torch.arange(max_position).float()
|
t = torch.arange(max_position).float()
|
||||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
@ -92,12 +105,13 @@ def test_rotary_embedding_neox(
|
|||||||
positions,
|
positions,
|
||||||
out_query,
|
out_query,
|
||||||
out_key,
|
out_key,
|
||||||
|
head_size,
|
||||||
cos_sin_cache,
|
cos_sin_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
||||||
dim=head_size,
|
dim=rotary_dim,
|
||||||
max_position_embeddings=max_position,
|
max_position_embeddings=max_position,
|
||||||
base=base,
|
base=base,
|
||||||
).to(dtype=dtype, device='cuda')
|
).to(dtype=dtype, device='cuda')
|
||||||
@ -123,5 +137,6 @@ if __name__ == '__main__':
|
|||||||
num_heads=5,
|
num_heads=5,
|
||||||
head_size=head_size,
|
head_size=head_size,
|
||||||
max_position=8192,
|
max_position=8192,
|
||||||
|
rotary_dim=int(head_size * 0.25),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user