diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 548f913c..d66f61a8 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -155,11 +155,21 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            # This is a special case to avoid FP16 overflow
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                # This is a special case to avoid FP16 overflow
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -531,6 +541,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -551,9 +562,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
+            # This is a special case to avoid FP16 overflow
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
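
A minimal standalone sketch (not part of the patch) of the FP16 overflow the diff works around. The routed_scaling_factor value of 16.0 and the activation magnitudes below are assumed for illustration only; the point is that scaling expert outputs up past the fp16 maximum of 65504 yields inf, while deferring the scaling (as the patch does by rescaling the shared-expert branch and the residual stream by 1 / routed_scaling_factor instead) keeps values finite.

import torch

routed_scaling_factor = 16.0  # assumed config value, for illustration only
expert_out = torch.full((4,), 5000.0, dtype=torch.float16)   # hypothetical MoE output
shared_out = torch.full((4,), 5000.0, dtype=torch.float16)   # hypothetical shared-expert output

# Naive scaling: 5000 * 16 = 80000 > 65504, so fp16 overflows to inf.
naive = expert_out * routed_scaling_factor
print(naive)  # tensor of inf values

# Deferred scaling as in the patch: leave the MoE output unscaled and
# divide the shared-expert branch by routed_scaling_factor instead, so
# both branches live on the smaller scale and stay finite.
combined = expert_out + shared_out * (1.0 / routed_scaling_factor)
print(combined)  # finite values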