[BugFix] 1D query fix for MoE models (#3597)

Nick Hill, 2024-03-24 16:00:16 -07:00 (committed by GitHub)
commit 41deac4a3d, parent af9e53496f
4 changed files with 15 additions and 15 deletions
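
vLLM now hands MoE layers a flattened "1D query" of shape [num_tokens, hidden_dim] rather than a 3D [batch_size, sequence_length, hidden_dim] batch, so the three-way shape unpack in the MoE forward passes no longer matches. A minimal sketch of the shape change (sizes are illustrative, not taken from the diff):

import torch

hidden_dim = 8
# Old assumption: 3D input [batch_size, sequence_length, hidden_dim]
# New layout:     2D input [num_tokens, hidden_dim]
hidden_states = torch.randn(64, hidden_dim)

# The old unpack expects three dimensions and raises ValueError on 2D input:
#   batch_size, sequence_length, hidden_dim = hidden_states.shape
# The fixed unpack matches the flattened layout:
num_tokens, hidden_dim = hidden_states.shape
assert (num_tokens, hidden_dim) == (64, 8)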

View File

@@ -81,11 +81,13 @@ def test_mixtral_moe(dtype: torch.dtype):
         vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
-    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    # vLLM uses 1D query [num_tokens, hidden_dim]
+    vllm_inputs = hf_inputs.flatten(0, 1)
 
     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(inputs)
-    vllm_states = vllm_moe.forward(inputs)
+    hf_states, _ = hf_moe.forward(hf_inputs)
+    vllm_states = vllm_moe.forward(vllm_inputs)
 
     mixtral_moe_tol = {
         torch.float32: 1e-3,
@@ -93,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         torch.bfloat16: 1e-2,
     }
 
-    assert torch.allclose(hf_states,
+    assert torch.allclose(hf_states.flatten(0, 1),
                           vllm_states,
                           rtol=mixtral_moe_tol[dtype],
                           atol=mixtral_moe_tol[dtype])

View File

@@ -150,11 +150,11 @@ class DeepseekMoE(nn.Module):
         self.w2 = self.w2.view(len(w2), *w2s[0].shape)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         if self.config.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
@@ -169,8 +169,7 @@ class DeepseekMoE(nn.Module):
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)
 
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_dim)
+        return final_hidden_states.view(num_tokens, hidden_dim)
 
 
 class DeepseekAttention(nn.Module):
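
The DeepseekMoE change above (and the MixtralMoE changes that follow) reduce to the same shape handling. A toy stand-in, with the routing and expert computation replaced by an identity, just to show the reshapes:

import torch

def moe_forward_shapes(hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens, hidden_dim = hidden_states.shape
    # view(-1, hidden_dim) is a no-op for an already-flattened 2D input,
    # so it can stay in place.
    hidden_states = hidden_states.view(-1, hidden_dim)
    final_hidden_states = hidden_states  # placeholder for routing + experts
    return final_hidden_states.view(num_tokens, hidden_dim)

assert moe_forward_shapes(torch.randn(64, 8)).shape == (64, 8)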

View File

@@ -124,9 +124,9 @@ class MixtralMoE(nn.Module):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_size = hidden_states.shape
+        num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
@@ -140,8 +140,7 @@ class MixtralMoE(nn.Module):
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)
 
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_size)
+        return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):

View File

@@ -132,9 +132,9 @@ class MixtralMoE(nn.Module):
                                      linear_method=None)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
@@ -158,7 +158,7 @@ class MixtralMoE(nn.Module):
                final_hidden_states.add_(current_hidden_states)
 
         return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+            num_tokens, hidden_dim)
 
 
 class MixtralAttention(nn.Module):
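
For the unfused (quantized) MoE path above, the routing math is already per token, so only the shape unpack and the final view had to change. A short check of the router shapes (8 experts is an arbitrary placeholder):

import torch
import torch.nn.functional as F

num_tokens, n_experts = 64, 8
router_logits = torch.randn(num_tokens, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

assert routing_weights.shape == (num_tokens, n_experts)
# Each token's expert weights sum to 1 regardless of how the batch was
# flattened.
assert torch.allclose(routing_weights.sum(dim=1), torch.ones(num_tokens))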