From ec7da6fcf32fc05efe5d7ba30d01d3d940f12a3c Mon Sep 17 00:00:00 2001
From: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Date: Wed, 9 Apr 2025 00:59:14 -0700
Subject: [PATCH] [BugFix] llama4 qknorm should be not shared across head (#16311)

Signed-off-by: Lu Fang
---
 vllm/model_executor/models/llama4.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 029f6044..3dbf352a 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
         self.n_rep = self.num_heads // self.num_kv_heads
-        self.q_norm = RMSNorm(
-            hidden_size=self.q_size,
-            eps=config.rms_norm_eps,
-            has_weight=False,
-            dtype=torch.float32,
-        ) if self.use_qk_norm else None
-        self.k_norm = RMSNorm(
-            hidden_size=self.kv_size,
+        self.qk_norm = RMSNorm(
+            hidden_size=self.head_dim,
             eps=config.rms_norm_eps,
             has_weight=False,
             dtype=torch.float32,
@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
 
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
-        if self.q_norm is not None:
-            q = self.q_norm(q.float()).to(q.dtype)
-        if self.k_norm is not None:
-            k = self.k_norm(k.float()).to(k.dtype)
+        if self.qk_norm is not None:
+            q = q.reshape(-1, self.num_heads, self.head_dim)
+            q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
+            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
+            k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
         # to NoPE layers, where the inference-time temperature tuning function
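
The following is a minimal, standalone sketch (not part of the patch) of the behavioral difference being fixed: an RMS norm computes its statistic over the last dimension, so normalizing the flattened q projection of width q_size = num_heads * head_dim shares one scale across all heads of a token, while reshaping so the last dimension is head_dim normalizes each head independently, as the patch does. The helper `rms_norm` and the shapes below are illustrative assumptions using plain torch, not vLLM APIs.

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Weightless RMS norm over the last dimension, analogous in spirit to
    # RMSNorm(has_weight=False) in the patch.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

num_tokens, num_heads, head_dim = 4, 8, 128
q = torch.randn(num_tokens, num_heads * head_dim, dtype=torch.float32)

# Old behavior: one norm over the flattened width num_heads * head_dim,
# so the RMS statistic is shared across all heads of a token.
q_shared = rms_norm(q)

# New behavior: reshape so the last dimension is head_dim, normalize each
# head on its own, then flatten back (mirrors the reshape in the patch).
q_per_head = rms_norm(q.reshape(num_tokens, num_heads, head_dim))
q_per_head = q_per_head.reshape(num_tokens, num_heads * head_dim)

# The two results differ whenever heads have different magnitudes.
print(torch.allclose(q_shared, q_per_head))  # typically False

Because the fixed norm operates on head_dim, a single RMSNorm module can serve both q and k despite their different total widths, which is why the patch collapses q_norm and k_norm into one qk_norm.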