"""1D GPT-NeoX model compatible with HuggingFace weights.""" from typing import Dict, List, Optional, Tuple import torch from torch import nn from cacheflow.models import InputMetadata from cacheflow.models.attention import GPTNeoXCacheFlowAttention from cacheflow.models.sample import Sampler from cacheflow.models.utils import (hf_model_weights_iterator, load_tensor_parallel_weights) from cacheflow.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) from cacheflow.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] class GPTNeoXAttention(nn.Module): def __init__(self, config): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.query_key_value = ColumnParallelLinear(config.hidden_size, 3 * config.hidden_size, gather_output=False, perform_initialization=False) self.dense = RowParallelLinear(config.hidden_size, config.hidden_size, input_is_parallel=True, perform_initialization=False) scaling = self.head_size ** -0.5 rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 self.attn = GPTNeoXCacheFlowAttention(scaling, rotary_dim) def forward( self, position_ids: torch.LongTensor, hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn( position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event) output, _ = self.dense(attn_output) return output class GPTNeoXMLP(nn.Module): def __init__(self, config): super().__init__() self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, config.intermediate_size, gather_output=False, perform_initialization=False) self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size, input_is_parallel=True, perform_initialization=False) if config.hidden_act != 'gelu': raise ValueError(f'Unsupported activation: {config.hidden_act}. 
class GPTNeoXMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size, config.intermediate_size,
            gather_output=False, perform_initialization=False)
        self.dense_4h_to_h = RowParallelLinear(
            config.intermediate_size, config.hidden_size,
            input_is_parallel=True, perform_initialization=False)
        if config.hidden_act != 'gelu':
            raise ValueError(f'Unsupported activation: {config.hidden_act}. '
                             'Only gelu is supported for now.')
        self.act = torch.nn.GELU()

    def forward(self, hidden_states):
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class GPTNeoXLayer(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.attention = GPTNeoXAttention(config)
        self.mlp = GPTNeoXMLP(config)

    def forward(
        self,
        position_ids: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.attention(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_input = self.post_attention_layernorm(hidden_states)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_input = self.post_attention_layernorm(attn_output)
            mlp_output = self.mlp(mlp_input)
            hidden_states = mlp_output + attn_output
        return hidden_states


class GPTNeoXModel(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_in = VocabParallelEmbedding(config.vocab_size,
                                               config.hidden_size,
                                               perform_initialization=False)
        self.layers = nn.ModuleList(
            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_in(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
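
# Illustration only (never called by this module): a minimal sketch of the two
# residual layouts implemented in GPTNeoXLayer.forward. The "attn" and "mlp"
# callables are hypothetical stand-ins for GPTNeoXAttention and GPTNeoXMLP.
def _demo_residual_layouts() -> None:
    x = torch.randn(2, 4)
    ln1 = nn.LayerNorm(4)
    ln2 = nn.LayerNorm(4)

    def attn(t: torch.Tensor) -> torch.Tensor:
        return 0.5 * t  # hypothetical attention stand-in

    def mlp(t: torch.Tensor) -> torch.Tensor:
        return t + 1.0  # hypothetical MLP stand-in

    # use_parallel_residual=True: both branches read the same input x.
    #   x = x + attn(ln1(x)) + mlp(ln2(x))
    parallel = x + attn(ln1(x)) + mlp(ln2(x))
    # use_parallel_residual=False: the MLP reads the attention output.
    #   x = x + attn(ln1(x)); x = x + mlp(ln2(x))
    h = x + attn(ln1(x))
    sequential = h + mlp(ln2(h))
    assert parallel.shape == sequential.shape == (2, 4)
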
class GPTNeoXForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gpt_neox = GPTNeoXModel(config)
        self.embed_out = ColumnParallelLinear(config.hidden_size,
                                              config.vocab_size,
                                              bias=False,
                                              gather_output=False,
                                              perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.gpt_neox(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.embed_out.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_in.weight", "embed_out.weight",
                                "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self, model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if ("attention.bias" in name or "attention.masked_bias" in name
                    or "rotary_emb.inv_freq" in name):
                continue
            param = state_dict[name]
            if "query_key_value" in name:
                # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank
                    :shard_size * (tensor_model_parallel_rank + 1)]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if 'query_key_value.weight' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif 'query_key_value.bias' in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected weight name: {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights)

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-1e-3, 1e-3)
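
# Illustration only (never called by this module): a minimal sketch of the QKV
# weight conversion performed in GPTNeoXForCausalLM.load_weights. HuggingFace
# GPT-NeoX checkpoints store the fused weight as
# [num_heads * 3 * head_size, hidden_size] (q, k, v interleaved per head),
# while the layout expected here is [3 * num_heads * head_size, hidden_size]
# (all q rows, then all k rows, then all v rows). The sizes below are
# assumptions for demonstration, and the function name is hypothetical.
def _demo_qkv_weight_conversion() -> None:
    num_heads, head_size = 2, 3
    hidden_size = num_heads * head_size
    # HF layout: heads outermost, then (q, k, v), then head_size rows.
    hf_weight = torch.arange(
        num_heads * 3 * head_size * hidden_size,
        dtype=torch.float32).view(num_heads * 3 * head_size, hidden_size)
    converted = hf_weight.view(num_heads, 3, head_size, hidden_size)
    converted = converted.transpose(0, 1)  # -> [3, num_heads, head_size, h]
    converted = converted.reshape(-1, hidden_size)
    # The first num_heads * head_size rows are now all query rows, followed
    # by all key rows and then all value rows.
    assert converted.shape == (3 * num_heads * head_size, hidden_size)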