271 lines
9.8 KiB
Python
Raw Normal View History

2023-02-09 11:25:37 +00:00
"""1D OPT model compatible with HuggingFace weights."""
2023-02-23 09:31:55 +00:00
from typing import Dict, List, Optional, Tuple
2023-02-09 11:25:37 +00:00
import torch
from torch import nn
from transformers import OPTConfig
from transformers import PreTrainedModel
2023-02-23 09:31:55 +00:00
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.sequence import SequenceOutputs
2023-02-23 09:31:55 +00:00
# Per-layer cache pair, ordered (key_cache, value_cache); OPTAttention.forward
# unpacks it in exactly that order before handing both to the attention op.
KVCache = Tuple[torch.Tensor, torch.Tensor]
2023-02-09 11:25:37 +00:00
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embeddings with OPT's fixed index shift.

    OPT offsets position ids by 2 (a convention tied to how its padding_idx
    was originally handled), so the table is allocated with two extra rows
    and every incoming position is shifted by that amount before lookup.
    Other model families do not need this adjustment.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        shift = 2
        # Reserve `shift` extra rows so the largest shifted position still
        # indexes inside the table.
        super().__init__(num_embeddings + shift, embedding_dim)
        self.offset = shift

    def forward(self, positions: torch.LongTensor):
        """Return the embeddings for ``positions`` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class OPTAttention(nn.Module):
    """OPT multi-head self-attention whose attention step uses a KV cache.

    The query/key/value/output projections are ordinary ``nn.Linear`` layers.
    The attention computation itself, including any use of the key and value
    caches, is delegated to ``OPTCacheFlowAttention``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        """
        Args:
            embed_dim: Model hidden size; input and output width of every
                projection.
            num_heads: Number of attention heads. Must evenly divide
                ``embed_dim``.
            bias: Whether the projections carry a bias term.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``embed_dim``. Otherwise ``head_dim`` would be silently
                truncated and ``head_dim * num_heads`` would not equal
                ``embed_dim``.
        """
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got embed_dim="
                f"{embed_dim} and num_heads={num_heads}).")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # 1/sqrt(head_dim), passed to OPTCacheFlowAttention as its scale.
        self.scaling = self.head_dim**-0.5

        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
        # Creation order (k, v, q, out) is kept as-is: it determines parameter
        # registration order and the RNG draws used by default initialization.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to Q/K/V, run cached attention, and project back.

        Args:
            hidden_states: Input activations. The last dimension must be
                ``embed_dim`` (required by the ``nn.Linear`` projections).
                Leading dimensions are assumed to follow whatever layout
                ``OPTCacheFlowAttention`` expects; TODO confirm.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata, forwarded unchanged to the
                attention op.
            cache_event: Optional CUDA event, forwarded unchanged to the
                attention op; its semantics are defined there.

        Returns:
            The attention output after ``out_proj``.
        """
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output = self.out_proj(attn_output)
        return output
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a ReLU MLP.

    Layer-norm placement follows ``config.do_layer_norm_before``. When it is
    true, each norm is applied to its sub-layer's input (before the sub-layer).
    Otherwise each norm is applied after the residual addition.
    """

    def __init__(self, config: OPTConfig):
        """
        Args:
            config: HuggingFace OPT configuration.

        Raises:
            ValueError: If ``config.activation_function`` is not ``'relu'``,
                the only activation this layer implements.
        """
        super().__init__()
        # Validate with an explicit raise rather than ``assert``. Asserts are
        # stripped under ``python -O``, which would silently run ReLU for a
        # config that asks for a different activation.
        if config.activation_function != 'relu':
            raise ValueError(
                f"Unsupported activation function "
                f"{config.activation_function!r}; only 'relu' is supported.")
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.activation_fn = nn.ReLU()
        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = nn.Linear(
            self.embed_dim, config.ffn_dim, bias=config.enable_bias)
        self.fc2 = nn.Linear(
            config.ffn_dim, self.embed_dim, bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            hidden_states: Input activations (last dimension ``hidden_size``).
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to attention.
            cache_event: Optional CUDA event forwarded to attention.

        Returns:
            The block's output activations.
        """
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # Pre-LN variants apply the final layer norm BEFORE the feed-forward
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # Post-LN variants (e.g. 350m) apply it AFTER the residual addition
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
class OPTPreTrainedModel(PreTrainedModel):
    """Common HuggingFace ``PreTrainedModel`` settings shared by the OPT classes.

    The class attributes tell ``transformers`` which config class to use,
    under which prefix the base model lives, which module type must not be
    split across devices, and which checkpoint keys to ignore when they are
    unexpected.
    """

    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OPTDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module) -> None:
        """Deliberately skip per-module weight initialization.

        NOTE(review): presumably a no-op because weights are loaded from a
        pretrained checkpoint afterwards, so initializing them would be wasted
        work; confirm with the loading path.
        """
        _ = module  # accepted for API compatibility, intentionally ignored
class OPTDecoder(OPTPreTrainedModel):
    """OPT decoder stack: token/position embeddings, layers, final norm.

    When ``word_embed_proj_dim`` differs from ``hidden_size``, token
    embeddings are projected into the hidden width on the way in. The final
    hidden states are projected back to ``word_embed_proj_dim`` on the way out.
    """

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Both projections exist only when the embedding width differs from
        # the hidden width. project_out is built before project_in to keep the
        # same module registration order and default-init RNG consumption.
        needs_projection = config.word_embed_proj_dim != config.hidden_size
        self.project_out = (
            nn.Linear(config.hidden_size, config.word_embed_proj_dim,
                      bias=False)
            if needs_projection else None)
        self.project_in = (
            nn.Linear(config.word_embed_proj_dim, config.hidden_size,
                      bias=False)
            if needs_projection else None)

        # `config._remove_final_layer_norm` exists only for backward
        # compatibility with checkpoints fine-tuned before transformers
        # v4.20.1; see https://github.com/facebookresearch/metaseq/pull/164
        keep_final_norm = (config.do_layer_norm_before
                           and not config._remove_final_layer_norm)
        self.final_layer_norm = (
            nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
            if keep_final_norm else None)

        decoder_layers = [
            OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids, same shape as ``input_ids``
                (assumed; they are added elementwise after embedding).
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer number.
            input_metadata: Batch metadata forwarded to every layer.
            cache_events: Optional per-layer CUDA events; when ``None``,
                every layer receives ``None``.

        Returns:
            Final hidden states, projected to ``word_embed_proj_dim`` when a
            projection is configured.
        """
        position_embeds = self.embed_positions(positions)
        token_embeds = self.embed_tokens(input_ids)
        if self.project_in is not None:
            token_embeds = self.project_in(token_embeds)
        hidden_states = token_embeds + position_embeds

        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_event = (None if cache_events is None
                           else cache_events[layer_idx])
            hidden_states = decoder_layer(
                hidden_states, kv_caches[layer_idx], input_metadata,
                layer_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states
class OPTModel(OPTPreTrainedModel):
    """Base OPT model: wraps ``OPTDecoder`` under the ``decoder`` attribute.

    The extra level of nesting mirrors the HuggingFace checkpoint layout
    (``model.decoder.*``), so parameter names line up when loading weights.
    """

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Delegate straight to the decoder; arguments are passed through."""
        decoder = self.decoder
        return decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)
2023-02-09 11:25:37 +00:00
class OPTForCausalLM(OPTPreTrainedModel):
    """OPT with an LM head, returning sampled next tokens instead of logits.

    The LM head's weight matrix is handed directly to ``Sampler``, together
    with the final hidden states and the batch metadata.
    """

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
        # lm_head is expected to be tied to embed_tokens by transformers'
        # tie_weights(), which is why a missing lm_head.weight is ignored.
        self.lm_head = nn.Linear(
            config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.sampler = Sampler()
        # Initialize weights and apply final processing
        self.post_init()

    # NOTE(woosuk): While the following methods are not called in the model code,
    # they may be internally used by the transformers library.
    # For example, tie_weights() does not work without these methods.
    # Thus, do not delete these methods.
    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the base model, then sample next tokens from the LM head.

        Returns:
            Whatever ``Sampler`` produces. The annotation says a mapping of
            int to ``SequenceOutputs``; the key meaning is defined by
            ``Sampler`` (TODO confirm).
        """
        return self.sampler(
            self.lm_head.weight,
            self.model(
                input_ids, positions, kv_caches, input_metadata,
                cache_events),
            input_metadata)