[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints (#16674)

commit 92edf35826 (parent eb5819b2d9)
Author: Sijia (Jackson) Chen
Date: 2025-04-17 11:44:34 -07:00
Committed by: GitHub


@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
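
For context, here is a minimal sketch (not vLLM code; `expert` is a hypothetical stand-in for a single expert MLP) of why folding the router weight into the input is equivalent to weighting the expert output when topk == 1, which is why the change can safely reset the weights passed to the kernel to ones:

    import torch

    torch.manual_seed(0)
    num_tokens, hidden = 4, 8
    hidden_states = torch.randn(num_tokens, hidden)
    topk_weights = torch.rand(num_tokens, 1)  # shape (num_tokens, topk) with topk == 1
    expert = torch.nn.Linear(hidden, hidden, bias=False)  # hypothetical stand-in expert

    # Usual path: router weight scales the expert output.
    out_weight_on_output = expert(hidden_states) * topk_weights

    # Path added in this commit: weight folded into the input, and the weights
    # handed to the kernel are reset to ones so they are not applied twice.
    scaled = hidden_states * topk_weights.to(hidden_states.dtype)
    out_weight_on_input = expert(scaled) * torch.ones_like(topk_weights)

    torch.testing.assert_close(out_weight_on_output, out_weight_on_input)

Because the expert is a linear map, scaling its input by a per-token scalar is the same (up to floating-point rounding) as scaling its output, so the equivalence only holds for topk == 1; this is what the topk == 1 assertion in the diff enforces.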