[🚀 Ready to be merged] Added support for Jais models (#3183)
This commit is contained in:
parent
3bbff9e5ab
commit
4c07dd28c0
@ -76,6 +76,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
|
@ -66,7 +66,11 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`InternLM2ForCausalLM`
|
||||
- InternLM2
|
||||
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`JAISLMHeadModel`
|
||||
- Jais
|
||||
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
|
||||
-
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
||||
|
@ -27,6 +27,7 @@ _MODELS = {
|
||||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
|
@ -242,8 +242,7 @@ class GPT2LMHeadModel(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(self.lm_head_weight, logits,
|
||||
sampling_metadata)
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
|
351
vllm/model_executor/models/jais.py
Normal file
351
vllm/model_executor/models/jais.py
Normal file
@ -0,0 +1,351 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
# Copyright 2023 Cerebras Systems.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Jais model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.transformers_utils.configs import JAISConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, )
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class SwiGLUActivation(nn.Module):
|
||||
|
||||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
||||
return x1 * nn.functional.silu(x2)
|
||||
|
||||
|
||||
def _get_alibi_slopes(n):
|
||||
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2**(-(2**-(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(n))
|
||||
return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
|
||||
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
|
||||
|
||||
|
||||
class JAISAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
if hasattr(config, "scale_qk_dot_by_d"):
|
||||
config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
|
||||
self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
|
||||
self.scale = self.head_dim**-self.attn_scale_power
|
||||
|
||||
self.c_attn = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end]
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class JAISMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.swiglu = config.activation_function == "swiglu"
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_fc2 = (ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
) if self.swiglu else None)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.act = SwiGLUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.swiglu:
|
||||
hidden_states2, _ = self.c_fc2(hidden_states)
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
hidden_states = (self.act(hidden_states, hidden_states2)
|
||||
if self.swiglu else self.act(hidden_states))
|
||||
hidden_states, _ = self.c_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JAISBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = JAISAttention(config, linear_method)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = JAISMLP(inner_dim, config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JAISModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
assert not config.scale_attn_by_inverse_layer_idx
|
||||
assert not config.reorder_and_upcast_attn
|
||||
self.embed_dim = config.hidden_size
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = (nn.Embedding(config.max_position_embeddings,
|
||||
self.embed_dim)
|
||||
if config.position_embedding_type != "alibi" else None)
|
||||
if hasattr(config, "embeddings_scale"):
|
||||
self.embeddings_scale = config.embeddings_scale
|
||||
else:
|
||||
self.embeddings_scale = config.mup_embeddings_scale
|
||||
self.h = nn.ModuleList([
|
||||
JAISBlock(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
if self.wpe is not None:
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
else:
|
||||
hidden_states = inputs_embeds
|
||||
hidden_states *= torch.tensor(float(self.embeddings_scale),
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
for i in range(len(self.h)):
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JAISLMHeadModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = JAISModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
if hasattr(config, "width_scale"):
|
||||
self.output_logits_scale = config.width_scale
|
||||
else:
|
||||
self.output_logits_scale = (config.mup_output_alpha *
|
||||
config.mup_width_scale)
|
||||
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
|
||||
scale=self.output_logits_scale)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head_weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
continue
|
||||
if ".attn.bias" in name or ".attn.masked_bias" in name:
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
if "relative_pe" in name:
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = params_dict[name]
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
# Note(zhuohan): the logic below might break quantized models.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
if conv1d_weight_name not in name:
|
||||
continue
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
@ -10,6 +10,7 @@ _CONFIG_REGISTRY = {
|
||||
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
|
||||
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
|
||||
"starcoder2": Starcoder2Config,
|
||||
"jais": JAISConfig,
|
||||
}
|
||||
|
||||
|
||||
|
@ -5,10 +5,12 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config
|
||||
from vllm.transformers_utils.configs.jais import JAISConfig
|
||||
|
||||
__all__ = [
|
||||
"ChatGLMConfig",
|
||||
"MPTConfig",
|
||||
"RWConfig",
|
||||
"Starcoder2Config",
|
||||
"JAISConfig",
|
||||
]
|
||||
|
234
vllm/transformers_utils/configs/jais.py
Normal file
234
vllm/transformers_utils/configs/jais.py
Normal file
@ -0,0 +1,234 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright 2023 Cerebras Systems.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""JAIS configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class JAISConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a
|
||||
[`JAISModel`]. It is used to instantiate a JAIS model according to the
|
||||
specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used
|
||||
to control the model outputs. Read the documentation from
|
||||
[`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 50257):
|
||||
Vocabulary size of the JAIS model. Defines the number of different
|
||||
tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`JAISModel`].
|
||||
n_positions (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used
|
||||
with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
n_embd (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the embeddings and hidden states.
|
||||
n_layer (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
n_head (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the
|
||||
Transformer encoder.
|
||||
n_inner (`int`, *optional*, defaults to None):
|
||||
Dimensionality of the inner feed-forward layers. `None` will set
|
||||
it to 4 times n_embd
|
||||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
Activation function, to be selected in the list
|
||||
`["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`.
|
||||
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in
|
||||
the embeddings, encoder, and pooler.
|
||||
embd_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the embeddings.
|
||||
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon to use in the layer normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
||||
Scale attention weights by dividing by sqrt(hidden_size)..
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values
|
||||
attentions (not used by all models).
|
||||
scale_attn_by_inverse_layer_idx (`bool`, *optional*,
|
||||
defaults to `False`):
|
||||
Whether to additionally scale attention weights by
|
||||
`1 / layer_idx + 1`.
|
||||
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
||||
Whether to scale keys (K) prior to computing attention
|
||||
(dot-product)
|
||||
and upcast attention dot-product/softmax to float() when training
|
||||
with mixed precision.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"learned"`):
|
||||
Positional embedding can be either `"alibi"` or `"learned"`.
|
||||
mup_width_scale (`float`, *optional*, defaults to 1.0):
|
||||
muP parameter to scale learning rate and initializers. Calculated
|
||||
as (`d_model,0 / d_model`), where
|
||||
`d_model` is the model's width and `d_model,0` is the proxy
|
||||
model's width.
|
||||
mup_embeddings_scale (`float`, *optional*, defaults to 1.0):
|
||||
muP parameter to scale token and position embeddings.
|
||||
mup_output_alpha (`float`, *optional*, defaults to 1.0):
|
||||
muP parameter to scale output logits
|
||||
(`output_logits_scale = mup_output_alpha * mup_width_scale`).
|
||||
mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`):
|
||||
Scale attention weights by dividing by hidden_size instead of
|
||||
sqrt(hidden_size). Need to set scale_attn_weights to `True` as
|
||||
well.
|
||||
alibi_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for ALiBi
|
||||
embeddings. Currently only supports linear
|
||||
scaling strategy. Can specify either the scaling `factor` (must be
|
||||
a float greater than 1) for fixed scaling
|
||||
or `train_seq_len` for dynamic scaling on input samples with
|
||||
sequence length > `train_seq_len`. The expected
|
||||
formats are `{"type": strategy name, "factor": scaling factor}` or
|
||||
`{"type": strategy name,
|
||||
"train_seq_len": training sequence length}`.
|
||||
architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']):
|
||||
architecture names for Jais.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import JAISConfig, JAISModel
|
||||
|
||||
>>> # Initializing a JAIS configuration
|
||||
>>> configuration = JAISConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = JAISModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "jais"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "n_embd",
|
||||
"max_position_embeddings": "n_positions",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50257,
|
||||
n_positions=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
n_inner=None,
|
||||
activation_function="gelu_new",
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
scale_attn_weights=True,
|
||||
use_cache=True,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
scale_attn_by_inverse_layer_idx=False,
|
||||
reorder_and_upcast_attn=False,
|
||||
position_embedding_type="learned",
|
||||
mup_width_scale=1.0,
|
||||
mup_embeddings_scale=1.0,
|
||||
mup_output_alpha=1.0,
|
||||
mup_scale_qk_dot_by_d=False,
|
||||
alibi_scaling=None,
|
||||
architectures=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.n_inner = n_inner
|
||||
self.activation_function = activation_function
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_attn_weights = scale_attn_weights
|
||||
self.use_cache = use_cache
|
||||
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
|
||||
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self.mup_width_scale = mup_width_scale
|
||||
self.mup_embeddings_scale = mup_embeddings_scale
|
||||
self.mup_output_alpha = mup_output_alpha
|
||||
self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d
|
||||
|
||||
self.alibi_scaling = alibi_scaling
|
||||
self._alibi_scaling_validation()
|
||||
if architectures is None:
|
||||
architectures = ["JAISLMHeadModel"]
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
architectures=architectures,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _alibi_scaling_validation(self):
|
||||
"""
|
||||
Validate the `alibi_scaling` configuration.
|
||||
"""
|
||||
if self.alibi_scaling is None:
|
||||
return
|
||||
|
||||
if (not isinstance(self.alibi_scaling, dict)
|
||||
or len(self.alibi_scaling) != 2):
|
||||
raise ValueError(
|
||||
"`alibi_scaling` must be a dictionary with two fields,"
|
||||
"`type` and `factor` or `type` and `train_seq_len`, "
|
||||
f"got {self.alibi_scaling}")
|
||||
alibi_scaling_type = self.alibi_scaling.get("type", None)
|
||||
alibi_scaling_factor = self.alibi_scaling.get("factor", None)
|
||||
alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None)
|
||||
if alibi_scaling_type is None or alibi_scaling_type != "linear":
|
||||
raise ValueError(f"`alibi_scaling`'s type field must be 'linear',"
|
||||
f"got {alibi_scaling_type}")
|
||||
if (alibi_scaling_factor is not None
|
||||
and not isinstance(alibi_scaling_factor, float)
|
||||
or alibi_scaling_factor <= 1.0):
|
||||
raise ValueError(
|
||||
f"`alibi_scaling`'s factor field must be a float > 1.0,"
|
||||
f"got {alibi_scaling_factor}")
|
||||
if (alibi_dynamic_scaling is not None
|
||||
and not isinstance(alibi_dynamic_scaling, int)
|
||||
or alibi_dynamic_scaling <= 1):
|
||||
raise ValueError(
|
||||
f"`alibi_scaling`'s `train_seq_len` field must be an"
|
||||
f"integer > 1, got {alibi_dynamic_scaling}")
|
Loading…
x
Reference in New Issue
Block a user