[Bugfix] Fix FP16 overflow for DeepSeek V2 (#13232)
Signed-off-by: Yida Wu <yida.wu@amd.com>
Parent: 4290b704ff
Commit: c982ac5722
@@ -155,11 +155,21 @@ class DeepseekV2MoE(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # This is a special case to avoid FP16 overflow
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                # This is a special case to avoid FP16 overflow
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
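A minimal sketch (not part of the patch; the magnitude and the 16.0 factor are assumed for illustration) of the overflow the dtype branch above avoids: float16 saturates at 65504, so multiplying already-large expert outputs by routed_scaling_factor can jump straight to inf, which is why the FP16 path defers that multiply.

import torch

# Illustrative values only: float16 overflows to inf past finfo.max (65504).
print(torch.finfo(torch.float16).max)                       # 65504.0
expert_out = torch.full((4,), 8192.0, dtype=torch.float16)  # assumed magnitude
routed_scaling_factor = 16.0                                # assumed factor
print(expert_out * routed_scaling_factor)  # tensor([inf, inf, inf, inf], dtype=torch.float16)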
@@ -531,6 +541,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
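The cached attribute mirrors the routed_scaling_factor field already carried by the DeepSeek V2 checkpoint config. A quick way to inspect that field (the model id and trust_remote_code flag below are illustrative assumptions, not part of this change):

from transformers import AutoConfig

# Assumed model id; DeepSeek V2 ships a custom config class, hence trust_remote_code.
cfg = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True)
print(cfg.routed_scaling_factor)  # scaling applied to routed-expert outputs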
@@ -551,9 +562,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
 
 
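Both new guards key on torch.float16 only; bfloat16 and float32 keep the original code path. A short sketch of why that distinction holds: bfloat16 shares float32's exponent range, so the scaling multiply has ample headroom there.

import torch

# Dynamic-range comparison (illustrative): only float16 saturates near 6.5e4.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
print(torch.finfo(torch.float32).max)   # ~3.40e+38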