"""1D OPT model compatible with HuggingFace weights.""" from typing import Dict, List, Optional, Tuple import torch from torch import nn from transformers import OPTConfig from cacheflow.model_executor.input_metadata import InputMetadata from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention from cacheflow.model_executor.layers.sampler import Sampler from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from cacheflow.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from cacheflow.model_executor.parallel_utils.tensor_parallel import ( VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) from cacheflow.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] class OPTLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) def forward(self, positions: torch.LongTensor): return super().forward(positions + self.offset) class OPTAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, bias: bool = True, ) -> None: super().__init__() self.embed_dim = embed_dim tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() total_num_heads = num_heads assert num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim ** -0.5 self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias, gather_output=False, perform_initialization=False) self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) self.attn = GPTCacheFlowAttention(scale=self.scaling) def forward( self, hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) output, _ = self.out_proj(attn_output) return output class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.self_attn = OPTAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, ) self.do_layer_norm_before = config.do_layer_norm_before assert config.activation_function == 'relu' self.activation_fn = nn.ReLU() self.self_attn_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim, bias=config.enable_bias, gather_output=False, perform_initialization=False) self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim, bias=config.enable_bias, input_is_parallel=True, perform_initialization=False) self.final_layer_norm = nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) def forward( self, hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # Self Attention residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, cache_event=cache_event) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Fully Connected residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states, _ = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states, _ = self.fc2(hidden_states) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) return hidden_states class OPTDecoder(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.config = config self.padding_idx = config.pad_token_id self.max_target_positions = config.max_position_embeddings self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.word_embed_proj_dim, perform_initialization=False) # Positional embeddings are replicated (not sharded). self.embed_positions = OPTLearnedPositionalEmbedding( config.max_position_embeddings, config.hidden_size) # Project out & in will be replicated if they exist. if config.word_embed_proj_dim != config.hidden_size: self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) else: self.project_in = None # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility # with checkpoints that have been fine-tuned before transformers v4.20.1 # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm( config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine ) else: self.final_layer_norm = None self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds = self.project_in(inputs_embeds) hidden_states = inputs_embeds + pos_embeds for i in range(len(self.layers)): if cache_events is None: cache_event = None else: cache_event = cache_events[i] layer = self.layers[i] hidden_states = layer( hidden_states, kv_caches[i], input_metadata, cache_event) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) if self.project_out is not None: hidden_states = self.project_out(hidden_states) return hidden_states class OPTModel(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.decoder = OPTDecoder(config) def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: return self.decoder( input_ids, positions, kv_caches, input_metadata, cache_events) class OPTForCausalLM(nn.Module): def __init__(self, config): super().__init__() self.config = config self.model = OPTModel(config) # TODO(zhuohan): create a new weight after implementing pipeline # parallelism self.lm_head_weight = self.model.decoder.embed_tokens.weight self.sampler = Sampler(config.vocab_size) def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], ) -> Dict[int, SequenceOutputs]: hidden_states = self.model( input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler( self.lm_head_weight, hidden_states, input_metadata) return next_tokens _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"] def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, use_np_cache): if "lm_head.weight" in name: continue if name.startswith("decoder."): name = "model." + name is_attention_weight = False for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): if att_weight_name not in name: continue param = state_dict[name.replace(att_weight_name, "qkv_proj")] shard_size = param.shape[0] // 3 loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank :shard_size * (tensor_model_parallel_rank + 1)] param_slice = param.data[shard_size * stride_id :shard_size * (stride_id + 1)] assert param_slice.shape == loaded_weight.shape param_slice.copy_(loaded_weight) is_attention_weight = True break if is_attention_weight: continue param = state_dict[name] load_tensor_parallel_weights(param, loaded_weight, name, self._column_parallel_weights, self._row_parallel_weights, tensor_model_parallel_rank)