[Hardware][TPU] workaround fix for MoE on TPU (#11764)
parent 8bddb73512
commit 263a870ee1
@@ -14,6 +14,8 @@ from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
@@ -46,6 +48,11 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
+    torch.testing.assert_close(iterative_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)


 @pytest.mark.parametrize("dtype",
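The torch_moe reference helper is not shown in this diff. As a reading aid, a plausible minimal equivalent (a sketch under assumed semantics, not the repo's actual helper) routes each token to its top-k experts and applies the same SwiGLU computation per expert:

import torch
import torch.nn.functional as F

def torch_moe_sketch(a, w1, w2, score, topk):
    # Hypothetical stand-in for the torch_moe reference used in the test above.
    topk_weights, topk_ids = score.softmax(dim=-1, dtype=torch.float).topk(topk, dim=-1)
    d = w1.shape[1] // 2  # first half of w1 is the gate, second half the up-projection
    out = torch.zeros_like(a)
    for e in range(w1.shape[0]):
        # zero out routing weights for tokens not assigned to expert e
        w = (topk_weights * (topk_ids == e)).sum(dim=-1, keepdim=True).to(a.dtype)
        h = F.silu(a @ w1[e, :d].t()) * (a @ w1[e, d:].t())
        out += w * (h @ w2[e].t())
    return out

This mirrors the iterative loop added in the new file below, which is why the test can expect agreement with the Triton kernel within atol=2e-2.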
@@ -20,7 +20,8 @@ if current_platform.is_cuda_alike():
 else:
     fused_experts = None  # type: ignore
 if current_platform.is_tpu():
-    from .moe_pallas import fused_moe as fused_moe_pallas
+    # the iterative moe implementation is used until the moe_pallas is fixed
+    from .moe_torch_iterative import fused_moe as fused_moe_pallas
 else:
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)
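Because the fallback is pure PyTorch, the swapped-in TPU path can be smoke-tested on CPU. A minimal sketch (tensor shapes are illustrative assumptions; the import path comes from the new file below):

import torch
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as fused_moe_pallas)

# small random inputs: 7 tokens, hidden 16, intermediate 32, 4 experts, top-2
tokens, hidden, inter, experts, topk = 7, 16, 32, 4, 2
a = torch.randn(tokens, hidden)
w1 = torch.randn(experts, inter * 2, hidden)
w2 = torch.randn(experts, hidden, inter)
gate = torch.randn(tokens, experts)
out = fused_moe_pallas(a, w1, w2, gate, topk, renormalize=True)
assert out.shape == a.shape  # output keeps the input's [*, hidden_size] shape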
vllm/model_executor/layers/fused_moe/moe_torch_iterative.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import torch
+import torch.nn.functional as F
+
+
+def fused_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> torch.Tensor:
+    """
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+    """
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    intermediate_size = w2.shape[-1]
+    dtype = hidden_states.dtype
+
+    hidden_states = hidden_states.view(num_tokens, hidden_size)
+    gating_output = gating_output.view(num_tokens, num_experts)
+    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = topk_weights.to(dtype)
+
+    final_hidden_states = None
+    for expert_idx in range(num_experts):
+        expert_w1 = w1[expert_idx]
+        expert_w2 = w2[expert_idx]
+        expert_mask = (selected_experts == expert_idx)
+        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
+        x = F.linear(hidden_states, expert_w1)
+        gate = F.silu(x[:, :intermediate_size])
+        x = x[:, intermediate_size:] * gate
+        x = F.linear(x, expert_w2)
+        current_hidden_states = x * expert_weights
+        if final_hidden_states is None:
+            final_hidden_states = current_hidden_states
+        else:
+            final_hidden_states = final_hidden_states + current_hidden_states
+
+    return final_hidden_states.view(orig_shape)  # type: ignore