[Model] Align nemotron config with final HF state and fix lm-eval-small (#7611)

Michael Goin 2024-08-16 18:56:34 -04:00 committed by GitHub
parent 37fd47e780
commit 44f26a9466
5 changed files with 29 additions and 35 deletions

View File

@@ -1,11 +1,11 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
-model_name: "nvidia/Minitron-4B-Base"
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
+model_name: "mgoin/Minitron-4B-Base-FP8"
 tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.252
+    value: 0.233
   - name: "exact_match,flexible-extract"
-    value: 0.252
+    value: 0.236
 limit: 1000
 num_fewshot: 5
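For context, a minimal sketch of how a baseline YAML like the one above could be checked against a fresh lm-eval run. This is not the actual .buildkite harness code; the loader, the measured-results layout, and the relative tolerance are all assumptions.

```python
import yaml

RTOL = 0.05  # assumed tolerance, not taken from the harness


def check_baseline(config_path: str, measured: dict) -> None:
    """Compare measured task metrics against the recorded baseline values."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: got {got}, expected {expected}")
```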

View File

@@ -4,7 +4,7 @@ Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base.yaml
+Minitron-4B-Base-FP8.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml

View File

@@ -774,7 +774,7 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
-    rotary_percent: float = 1.0,
+    partial_rotary_factor: float = 1.0,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -787,8 +787,8 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
-    if rotary_percent < 1.0:
-        rotary_dim = int(rotary_dim * rotary_percent)
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
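To make the renamed argument concrete, here is an illustration of what the branch above computes. The head size is an assumed number; the factor matches Nemotron's default of 0.5 from this PR, and the "only the leading rotary_dim dimensions get rotated" reading reflects how a partial rotary embedding is typically applied rather than anything shown in this hunk.

```python
# Illustrative arithmetic only (head_size is an assumed value).
head_size = 128
partial_rotary_factor = 0.5  # Nemotron default per this PR
rotary_dim = head_size
if partial_rotary_factor < 1.0:
    rotary_dim = int(rotary_dim * partial_rotary_factor)
assert rotary_dim == 64  # only these leading dims of q/k receive RoPE
```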

View File

@@ -53,7 +53,7 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 # - There is no gate_proj, just up_proj
 # - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm
 # - Squared ReLU instead of SwiGLU
-# - Adds a rotary_percent to RoPE
+# - Adds a partial_rotary_factor to RoPE


 def _cast_if_autocast_enabled(*args):
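For reference, a minimal sketch of the squared-ReLU activation the comment block mentions. The helper name here is made up for illustration; in vLLM the config's "relu2" activation string is resolved through the activation registry rather than a free function like this.

```python
import torch


def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU, used in place of SwiGLU in Nemotron's MLP: max(0, x) ** 2.
    return torch.square(torch.relu(x))
```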
@@ -161,7 +161,7 @@ class NemotronAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.rotary_percent = config.rope_percent
+        self.partial_rotary_factor = config.partial_rotary_factor
         self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = QKVParallelLinear(
@@ -187,7 +187,7 @@ class NemotronAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
-            rotary_percent=self.rotary_percent,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,

View File

@@ -35,20 +35,20 @@ class NemotronConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
+        vocab_size (`int`, *optional*, defaults to 256000):
             Vocabulary size of the Nemotron model. Defines the number of
             different tokens that can be represented by the
             `inputs_ids` passed when calling [`NemotronModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
+        hidden_size (`int`, *optional*, defaults to 6144):
             Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
+        intermediate_size (`int`, *optional*, defaults to 24576):
             Dimension of the MLP representations.
         num_hidden_layers (`int`, *optional*, defaults to 32):
             Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
+        num_attention_heads (`int`, *optional*, defaults to 48):
             Number of attention heads for each attention layer in the
             Transformer decoder.
-        head_dim (`int`, *optional*, defaults to None):
+        head_dim (`int`, *optional*):
             Projection weights dimension in multi-head attention. Set to
             hidden_size // num_attention_heads if None
         num_key_value_heads (`int`, *optional*):
@@ -63,16 +63,16 @@ class NemotronConfig(PretrainedConfig):
             heads within that group. For more details checkout
             [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it
             is not specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
             The non-linear activation function (function or string) in the
             decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
             The maximum sequence length that this model might ever be used
             with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
+        initializer_range (`float`, *optional*, defaults to 0.0134):
             The standard deviation of the truncated_normal_initializer for
             initializing all weight matrices.
-        norm_eps (`float`, *optional*, defaults to 1e-06):
+        norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the normalization layers.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values
@@ -80,21 +80,16 @@
             `config.is_decoder=True`.
         pad_token_id (`int`, *optional*):
             Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
+        bos_token_id (`int`, *optional*, defaults to 2):
             Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
+        eos_token_id (`int`, *optional*, defaults to 3):
             End of stream token id.
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear
-            and dynamic. Their scaling factor must be a float greater than 1.
-            The expected format is `{"type": strategy name,
-            "factor": scaling factor}`. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
+        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
+            Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output
             projection layers during self-attention.
@@ -106,13 +101,10 @@ class NemotronConfig(PretrainedConfig):
     ```python
     from transformers import NemotronModel, NemotronConfig
     # Initializing a Nemotron nemotron-15b style configuration
     configuration = NemotronConfig()
     # Initializing a model from the nemotron-15b style configuration
     model = NemotronModel(configuration)
     # Accessing the model configuration
     configuration = model.config
     ```"""
@@ -140,7 +132,7 @@ class NemotronConfig(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
-        rope_percent=0.5,
+        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -167,8 +159,10 @@ class NemotronConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
-        rope_percent = rope_percent or kwargs.get("rope_percentage", None)
-        self.rope_percent = rope_percent
+        # for backward compatibility
+        partial_rotary_factor = kwargs.get("rope_percent", None) or kwargs.get(
+            "rope_percentage", None) or partial_rotary_factor
+        self.partial_rotary_factor = partial_rotary_factor
         self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
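Finally, a hedged illustration of the backward-compatibility shim added above: a config that still carries the legacy rope_percent (or rope_percentage) key should resolve to the same partial_rotary_factor as one using the new name. The kwargs shown are hypothetical and the import path is assumed.

```python
from vllm.transformers_utils.configs import NemotronConfig  # assumed import path

# Old-style config still using the legacy key (hypothetical kwargs).
legacy = NemotronConfig(rope_percent=0.5)
# New-style config using the renamed field.
current = NemotronConfig(partial_rotary_factor=0.5)
assert legacy.partial_rotary_factor == current.partial_rotary_factor == 0.5
```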