[Bugfix] Fix GLM4 model (#16618)

Signed-off-by: intervitens <intervitens@tutanota.com>
Author: intervitens, 2025-04-17 13:35:07 +03:00, committed by GitHub
Commit: 5b1aca2ae3 (parent: d8e557b5e5)
3 changed files with 6 additions and 6 deletions

Changed file 1/3: supported models documentation (generative models table)

@@ -338,7 +338,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
 - * `Glm4ForCausalLM`
   * GLM-4-0414
-  * `THUDM/GLM-4-32B-Chat-0414`, etc.
+  * `THUDM/GLM-4-32B-0414`, etc.
   * ✅︎
   * ✅︎
 - * `GPT2LMHeadModel`
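
For reference, the renamed checkpoint is used like any other text-generation model. A minimal offline-inference sketch with vLLM's `LLM` API, assuming the `THUDM/GLM-4-32B-0414` weights and a new enough `transformers` are available locally; prompt and sampling settings are illustrative:

```python
from vllm import LLM, SamplingParams

# Illustrative only; the model name comes from the docs row above,
# everything else is arbitrary.
llm = LLM(model="THUDM/GLM-4-32B-0414")
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["Give a one-sentence summary of GLM-4-0414."], params)
print(outputs[0].outputs[0].text)
```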

Changed file 2/3: text-generation example model registry (tests)

@@ -147,7 +147,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                 min_transformers_version="4.50"),
     "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo(
-        "THUDM/GLM-4-32B-Chat-0414",
+        "THUDM/GLM-4-32B-0414",
         is_available_online=False,
         min_transformers_version="4.52.dev0"
     ),
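
The `min_transformers_version="4.52.dev0"` field gates the example on the installed `transformers` version. A hedged sketch of that kind of version gate, illustrating the idea behind the field rather than the registry's actual implementation:

```python
from importlib.metadata import version
from packaging.version import Version

# Illustrative version gate in the spirit of min_transformers_version;
# the real _HfExamplesInfo handling may differ.
def transformers_supports_glm4(min_version: str = "4.52.dev0") -> bool:
    return Version(version("transformers")) >= Version(min_version)
```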

Changed file 3/3: GLM-4 model implementation (Glm4Attention / Glm4DecoderLayer)

@@ -82,7 +82,7 @@ class Glm4Attention(nn.Module):
         partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         self.head_dim = head_dim or hidden_size // self.total_num_heads
-        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.rotary_dim = self.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
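
One reading of this hunk, taken together with the `partial_rotary_factor` argument visible in the next hunk: the rotary embedding is now handed the full `head_dim` and applies the partial factor itself, whereas the old code pre-scaled `rotary_dim` as well, so the factor could end up applied twice. A small arithmetic sketch under that assumption (values are illustrative):

```python
head_dim = 128
partial_rotary_factor = 0.5

# Old: rotary_dim pre-scaled here *and* scaled again downstream.
old_rotary_dim = int(partial_rotary_factor * head_dim)         # 64
double_scaled = int(partial_rotary_factor * old_rotary_dim)    # 32 -> too few rotated channels

# New: pass the full head_dim and let the rotary embedding apply the factor once.
new_rotary_dim = head_dim                                       # 128
single_scaled = int(partial_rotary_factor * new_rotary_dim)     # 64 -> intended
```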
@@ -110,6 +110,7 @@ class Glm4Attention(nn.Module):
             base=self.rope_theta,
             rope_scaling=rope_scaling,
             partial_rotary_factor=partial_rotary_factor,
+            is_neox_style=False,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
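
`is_neox_style=False` selects the interleaved (GPT-J style) pairing of channels for the rotation rather than the half-split (NeoX style) pairing. A minimal sketch of the difference, written as plain PyTorch for illustration rather than as vLLM's `RotaryEmbedding`:

```python
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX / "rotate half": channel i is paired with channel i + rotary_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J / interleaved (the is_neox_style=False case): channel 2i pairs with 2i + 1.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_partial_rope(x, cos, sin, rotary_dim, neox=False):
    # Only the first rotary_dim channels are rotated; the rest pass through.
    # cos/sin are assumed to already be expanded to match the chosen pairing.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    rotate = rotate_neox if neox else rotate_gptj
    return torch.cat((x_rot * cos + rotate(x_rot) * sin, x_pass), dim=-1)
```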
@@ -197,13 +198,12 @@ class Glm4DecoderLayer(nn.Module):
         )
         hidden_states = self.post_self_attn_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         # Fully Connected
-        hidden_states = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_mlp_layernorm(hidden_states)
         hidden_states = residual + hidden_states

         return hidden_states, residual
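
The decoder-layer change reads as a calling-convention fix: a fused add-RMSNorm that receives a `residual` argument returns a pair (normalized hidden states, updated residual), so the old single assignment bound the whole tuple to `hidden_states` and never refreshed the residual stream. A toy sketch of that contract, assuming vLLM-like semantics rather than reproducing its `RMSNorm`:

```python
import torch
import torch.nn as nn


class ToyRMSNorm(nn.Module):
    """Toy fused add-RMSNorm: norm(x) -> tensor, norm(x, residual) -> (tensor, tensor)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, residual=None):
        if residual is not None:
            x = x + residual   # fused residual add
            residual = x       # the pre-norm sum becomes the new residual
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed * self.weight
        return normed if residual is None else (normed, residual)
```

With that contract, `hidden_states, residual = norm(hidden_states, residual)` keeps both values, which is what the replacement lines above do.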