[Model] Align nemotron config with final HF state and fix lm-eval-small (#7611)
commit 44f26a9466 (parent 37fd47e780)
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml → .buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml

@@ -1,11 +1,11 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
-model_name: "nvidia/Minitron-4B-Base"
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
+model_name: "mgoin/Minitron-4B-Base-FP8"
 tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.252
+    value: 0.233
   - name: "exact_match,flexible-extract"
-    value: 0.252
+    value: 0.236
 limit: 1000
 num_fewshot: 5
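For context on how these numbers are used: the `value` entries are reference GSM8K scores that CI compares against a fresh lm-eval run of the model. The sketch below shows that kind of tolerance check; the `RTOL` value and the `check_results` name are illustrative assumptions, not necessarily what the harness's own test script does.

```python
# Illustrative sketch only: compare measured lm-eval scores against the expected
# values in a config like the one above. RTOL and check_results are assumptions
# for illustration, not vLLM's actual harness code.
import yaml

RTOL = 0.05  # assumed relative tolerance


def check_results(config_path: str, measured: dict) -> None:
    """measured maps (task_name, metric_name) -> score from the lm-eval run."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[(task["name"], metric["name"])]
            # Fail if the measured score drifts more than RTOL from the baseline.
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: expected ~{expected}, got {got}")
```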
.buildkite/lm-eval-harness/configs/models-small.txt

@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base.yaml
+Minitron-4B-Base-FP8.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml
vllm/model_executor/layers/rotary_embedding.py

@@ -774,7 +774,7 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
-    rotary_percent: float = 1.0,
+    partial_rotary_factor: float = 1.0,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -787,8 +787,8 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
-    if rotary_percent < 1.0:
-        rotary_dim = int(rotary_dim * rotary_percent)
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
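To make the renamed parameter concrete: with the documented Nemotron defaults (hidden_size 6144, 48 heads, so a head size of 128) and `partial_rotary_factor = 0.5`, only the first 64 dimensions of each head receive rotary embeddings. A standalone sketch of the same scaling step that `get_rope` performs:

```python
# Standalone sketch of the rotary_dim scaling that get_rope now does with
# partial_rotary_factor (previously rotary_percent).
def scaled_rotary_dim(head_size: int, partial_rotary_factor: float = 1.0) -> int:
    rotary_dim = head_size
    if partial_rotary_factor < 1.0:
        # Only a prefix of each head's dimensions receives rotary embedding.
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    return rotary_dim


assert scaled_rotary_dim(128, 0.5) == 64    # Nemotron: rotate half of each head
assert scaled_rotary_dim(128, 1.0) == 128   # default: full rotary (e.g. Llama)
```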
vllm/model_executor/models/nemotron.py

@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 # - There is no gate_proj, just up_proj
 # - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
 # - Squared ReLU instead of SwiGLU
-# - Adds a rotary_percent to RoPE
+# - Adds a partial_rotary_factor to RoPE


 def _cast_if_autocast_enabled(*args):
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.rotary_percent = config.rope_percent
+        self.partial_rotary_factor = config.partial_rotary_factor
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            rotary_percent=self.rotary_percent,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
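The comment block above lists Nemotron's departures from Llama. One of them, the `"relu2"` activation, is simply a squared ReLU in place of SwiGLU; a minimal sketch (assuming `relu2` means `relu(x) ** 2`, which is how I read vLLM's activation registry):

```python
import torch


def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU ("relu2"): zero for negative inputs, x**2 otherwise.
    return torch.square(torch.relu(x))


x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
print(relu_squared(x))  # tensor([0., 0., 0., 1., 9.])
```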
vllm/transformers_utils/configs/nemotron.py

@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):


     Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Nemotron model. Defines the number of
             different tokens that can be represented by the
             `inputs_ids` passed when calling [`NemotronModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
+        hidden_size (`int`, *optional*, defaults to 6144):
             Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
+        intermediate_size (`int`, *optional*, defaults to 24576):
             Dimension of the MLP representations.
         num_hidden_layers (`int`, *optional*, defaults to 32):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
+        num_attention_heads (`int`, *optional*, defaults to 48):
             Number of attention heads for each attention layer in the
             Transformer decoder.
-        head_dim (`int`, *optional*, defaults to None):
+        head_dim (`int`, *optional*):
             Projection weights dimension in multi-head attention. Set to
             hidden_size // num_attention_heads if None
         num_key_value_heads (`int`, *optional*):
@@ -63,16 +63,16 @@
             heads within that group. For more details checkout
             [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
             is not specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
             The non-linear activation function (function or string) in the
             decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
             The maximum sequence length that this model might ever be used
             with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
+        initializer_range (`float`, *optional*, defaults to 0.0134):
             The standard deviation of the truncated_normal_initializer for
             initializing all weight matrices.
-        norm_eps (`float`, *optional*, defaults to 1e-06):
+        norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values
@@ -80,21 +80,16 @@
             `config.is_decoder=True`.
         pad_token_id (`int`, *optional*):
             Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
+        bos_token_id (`int`, *optional*, defaults to 2):
             Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
+        eos_token_id (`int`, *optional*, defaults to 3):
             End of stream token id.
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear
-            and dynamic. Their scaling factor must be a float greater than 1.
-            The expected format is `{"type": strategy name,
-            "factor": scaling factor}`. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output
             projection layers during self-attention.
@@ -106,13 +101,10 @@

     ```python
     >>> from transformers import NemotronModel, NemotronConfig
-
     >>> # Initializing a Nemotron nemotron-15b style configuration
     >>> configuration = NemotronConfig()
-
     >>> # Initializing a model from the nemotron-15b style configuration
     >>> model = NemotronModel(configuration)
-
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
@@ -140,7 +132,7 @@
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
-        rope_percent=0.5,
+        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -167,8 +159,10 @@
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
-        rope_percent = rope_percent or kwargs.get("rope_percentage", None)
-        self.rope_percent = rope_percent
+        # for backward compatibility
+        partial_rotary_factor = kwargs.get("rope_percent", None) or kwargs.get(
+            "rope_percentage", None) or partial_rotary_factor
+        self.partial_rotary_factor = partial_rotary_factor
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
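The kwargs fallback added above keeps older checkpoints loadable: configs that still spell the field `rope_percent` or `rope_percentage` resolve to the same attribute as the new `partial_rotary_factor`. A standalone sketch of that precedence (the helper function name is illustrative, not part of the config class):

```python
# Sketch of the backward-compatibility precedence used in NemotronConfig.__init__:
# legacy kwargs win when present, otherwise the new argument (default 0.5) is used.
def resolve_partial_rotary_factor(partial_rotary_factor: float = 0.5,
                                  **kwargs) -> float:
    return (kwargs.get("rope_percent", None)
            or kwargs.get("rope_percentage", None)
            or partial_rotary_factor)


assert resolve_partial_rotary_factor() == 0.5                         # new default
assert resolve_partial_rotary_factor(rope_percent=0.5) == 0.5         # legacy key
assert resolve_partial_rotary_factor(rope_percentage=0.25) == 0.25    # older legacy key
assert resolve_partial_rotary_factor(partial_rotary_factor=1.0) == 1.0
```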