[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints (#16674)
This commit is contained in:
parent eb5819b2d9
commit 92edf35826
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
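
For reference, a minimal standalone sketch of what the new `apply_router_weight_on_input` branch does: with topk=1, scaling the bf16 activations by the router weight before the fused kernel runs, and then handing the kernel all-ones weights, is equivalent to weighting the expert output afterwards. The toy single-expert matmul and tensor shapes below are assumptions for illustration only; the real path calls aiter's fused MoE kernel rather than this matmul.

import torch

# Toy shapes and a single "expert" matmul, chosen only for illustration.
num_tokens, hidden_dim, topk = 4, 8, 1
hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, topk, dtype=torch.float32)
expert_w = torch.randn(hidden_dim, hidden_dim, dtype=torch.bfloat16)

# Usual ordering: run the expert, then apply the router weight to its output.
ref = (hidden_states @ expert_w) * topk_weights.to(torch.bfloat16)

# What the new branch does: fold the router weight into the input first...
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
# ...and pass all-ones weights so the weight is not applied a second time.
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
out = (hidden_states @ expert_w) * topk_weights.to(torch.bfloat16)

# For topk == 1 the two orderings agree up to bf16 rounding.
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)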