[BugFix] 1D query fix for MoE models (#3597)
commit 41deac4a3d (parent af9e53496f)
@@ -81,11 +81,13 @@ def test_mixtral_moe(dtype: torch.dtype):
         vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
-    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    # vLLM uses 1D query [num_tokens, hidden_dim]
+    vllm_inputs = hf_inputs.flatten(0, 1)
 
     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(inputs)
-    vllm_states = vllm_moe.forward(inputs)
+    hf_states, _ = hf_moe.forward(hf_inputs)
+    vllm_states = vllm_moe.forward(vllm_inputs)
 
     mixtral_moe_tol = {
         torch.float32: 1e-3,
@@ -93,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         torch.bfloat16: 1e-2,
     }
 
-    assert torch.allclose(hf_states,
+    assert torch.allclose(hf_states.flatten(0, 1),
                           vllm_states,
                           rtol=mixtral_moe_tol[dtype],
                           atol=mixtral_moe_tol[dtype])
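The shape convention the updated test exercises, as a minimal standalone sketch (plain PyTorch only; the 4096 hidden size below is illustrative, not taken from the Mixtral config):

import torch

# The HF MoE block (hf_moe in the test) takes a 3D batch
# of shape [batch_size, seq_len, hidden_dim] ...
batch_size, seq_len, hidden_size = 1, 64, 4096
hf_inputs = torch.randn(batch_size, seq_len, hidden_size)

# ... while vLLM's MoE forward now takes a 1D query, i.e. a flattened
# token list of shape [num_tokens, hidden_dim].
vllm_inputs = hf_inputs.flatten(0, 1)
assert vllm_inputs.shape == (batch_size * seq_len, hidden_size)

# The HF output stays 3D, so the test flattens it the same way before
# comparing: torch.allclose(hf_states.flatten(0, 1), vllm_states, ...).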
@@ -150,11 +150,11 @@ class DeepseekMoE(nn.Module):
         self.w2 = self.w2.view(len(w2), *w2s[0].shape)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         if self.config.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
@@ -169,8 +169,7 @@ class DeepseekMoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_dim)
+        return final_hidden_states.view(num_tokens, hidden_dim)
 
 
 class DeepseekAttention(nn.Module):
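For clarity, a toy stand-in for the new forward's shape contract (not the actual DeepseekMoE code; the gate, fused_moe call, shared experts, and all-reduce are elided):

import torch

def moe_forward_1d(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states now arrives as a 1D query: [num_tokens, hidden_dim],
    # so the old (batch_size, sequence_length, hidden_dim) unpack becomes a 2-way one.
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)  # effectively a no-op for 2D input
    # ... routing and expert computation would run here on [num_tokens, hidden_dim] ...
    out = hidden_states
    # A single view restores the shape; no (batch_size, sequence_length, ...) reshape remains.
    return out.view(num_tokens, hidden_dim)

tokens = torch.randn(64, 4096)  # e.g. a flattened [1, 64, 4096] batch
assert moe_forward_1d(tokens).shape == (64, 4096)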
@@ -124,9 +124,9 @@ class MixtralMoE(nn.Module):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_size = hidden_states.shape
+        num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
@@ -140,8 +140,7 @@ class MixtralMoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_size)
+        return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
@@ -132,9 +132,9 @@ class MixtralMoE(nn.Module):
                                      linear_method=None)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
@@ -158,7 +158,7 @@ class MixtralMoE(nn.Module):
             final_hidden_states.add_(current_hidden_states)
 
         return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+            num_tokens, hidden_dim)
 
 
 class MixtralAttention(nn.Module):
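The router shapes referenced by the updated comments, as a small hedged sketch (plain PyTorch; the top_k selection step is an assumption about this per-expert path, not copied from the diff):

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, n_experts, top_k = 64, 4096, 8, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
gate = torch.nn.Linear(hidden_dim, n_experts, bias=False)

# router_logits: (num_tokens, n_experts) -- one row per token, no batch/seq dims
router_logits = gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# Hypothetical top-k expert selection per token; both results are (num_tokens, top_k)
topk_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
print(router_logits.shape, topk_weights.shape)  # torch.Size([64, 8]) torch.Size([64, 2])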