[Bugfix] Fix deepseekv3 gate bias error (#12002)

Signed-off-by: mgoin <michael@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Steve Luo 2025-01-14 04:43:51 +08:00 committed by GitHub
parent 289b5191d5
commit f35ec461fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -497,7 +497,10 @@ def grouped_topk(hidden_states: torch.Tensor,
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0))
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
@ -510,10 +513,16 @@ def grouped_topk(hidden_states: torch.Tensor,
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)