diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index ac158a7e..4214e894 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
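
For context (not part of the diff): with topk == 1, the new `apply_router_weight_on_input` path folds the single router weight into `hidden_states` before the AITER kernel runs, then hands the kernel all-ones weights so the output is not scaled a second time. The standalone sketch below illustrates that reordering with a toy, purely linear `expert` function; the function name and tensor shapes are illustrative assumptions, and the exact match only holds because the toy expert is linear (real expert MLPs are nonlinear, so input-side weighting is a deliberate modeling choice rather than a free rewrite).

import torch

def expert(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for an expert: a single linear map with no bias,
    # which keeps output-side and input-side weighting exactly equivalent.
    w = torch.full((x.shape[-1], x.shape[-1]), 0.5)
    return x @ w

hidden_states = torch.randn(4, 8)   # (num_tokens, hidden_dim)
topk_weights = torch.rand(4, 1)     # (num_tokens, topk) with topk == 1

# Default path: router weight applied to the expert output.
out_weight_on_output = expert(hidden_states) * topk_weights

# What the hunk does: pre-scale the input, then pass all-ones weights.
scaled_input = hidden_states * topk_weights.to(hidden_states.dtype)
ones = torch.ones_like(topk_weights)
out_weight_on_input = expert(scaled_input) * ones

assert torch.allclose(out_weight_on_output, out_weight_on_input, atol=1e-6)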