From 114f0067434def302b3b278a99c82b50b982708c Mon Sep 17 00:00:00 2001
From: Yida Wu
Date: Mon, 6 Jan 2025 20:18:05 +0000
Subject: [PATCH] deepseek overflow fix

---
 vllm/model_executor/models/deepseek_v2.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 4cf4e6c358bf2..939892c5d66c8 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -149,9 +149,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+            router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states = final_hidden_states + shared_output * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -375,6 +375,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -399,9 +400,14 @@ def forward(
         )
 
         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE):
+            hidden_states *= 1. / self.mlp.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP):
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual