"""1D OPT model compatible with HuggingFace weights.""" from typing import Dict, List, Optional, Tuple import torch from torch import nn from transformers import OPTConfig from transformers import PreTrainedModel from cacheflow.models import InputMetadata from cacheflow.models.attention import OPTCacheFlowAttention from cacheflow.models.sample import Sampler from cacheflow.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] class OPTLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) def forward(self, positions: torch.LongTensor): return super().forward(positions + self.offset) class OPTAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, bias: bool = True, ) -> None: super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.scaling = self.head_dim**-0.5 # TODO(woosuk): Fuse the three linear layers into one QKV linear layer. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.attn = OPTCacheFlowAttention(scale=self.scaling) def forward( self, hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) output = self.out_proj(attn_output) return output class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.embed_dim = config.hidden_size self.self_attn = OPTAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, ) self.do_layer_norm_before = config.do_layer_norm_before assert config.activation_function == 'relu' self.activation_fn = nn.ReLU() self.self_attn_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) def forward( self, hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # Self Attention residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, cache_event=cache_event) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) 
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER the feed-forward layers.
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTPreTrainedModel(PreTrainedModel):
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OPTDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module) -> None:
        del module  # unused
        return


class OPTDecoder(OPTPreTrainedModel):

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.word_embed_proj_dim,
                                         self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been
        # fine-tuned before transformers v4.20.1.
        # See https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


class OPTModel(OPTPreTrainedModel):

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)


class OPTForCausalLM(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)
        # The lm_head weight is automatically tied to the embed_tokens weight.
        self.lm_head = nn.Linear(config.word_embed_proj_dim,
                                 config.vocab_size,
                                 bias=False)
        self.sampler = Sampler()

        # Initialize weights and apply final processing.
        self.post_init()

    # NOTE(woosuk): While the following methods are not called in the model
    # code, they may be internally used by the transformers library.
    # For example, tie_weights() does not work without these methods.
    # Thus, do not delete these methods.
    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens
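

# Illustrative usage sketch (not part of the model code, kept as comments so
# importing this module has no side effects). It shows how a caller might
# drive OPTForCausalLM; the construction of `input_metadata` (InputMetadata),
# the per-layer KV caches, and the optional CUDA cache events is specific to
# the cacheflow engine and is only assumed here, not defined by this file.
#
#     config = OPTConfig.from_pretrained("facebook/opt-125m")
#     model = OPTForCausalLM(config).eval().half().cuda()
#
#     # kv_caches: one (key_cache, value_cache) pair per decoder layer,
#     # allocated by the engine's cache manager (hypothetical here).
#     # input_metadata: prompt/decoding bookkeeping built by the scheduler.
#     # The forward pass returns sampled tokens (Dict[int, SequenceOutputs]),
#     # not raw logits, because sampling happens inside the model.
#     next_tokens = model(
#         input_ids, positions, kv_caches, input_metadata, cache_events=None)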