[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints (#16674)
commit 92edf35826
parent eb5819b2d9
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
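For orientation, the added pre-processing can be read in isolation. Below is a minimal standalone sketch of the same logic: with topk == 1, the router weight is folded into the activations before expert dispatch, and the weights handed onward are replaced with ones so the kernel's final combine step does not apply them a second time. The helper name and the toy tensors here are illustrative only, not part of the vLLM API.

import torch

def fold_router_weight_into_input(hidden_states: torch.Tensor,
                                  topk_weights: torch.Tensor,
                                  topk_ids: torch.Tensor):
    # Mirrors the hunk above: only the single-expert-per-token case is
    # supported, so the per-token scaling is unambiguous.
    assert topk_weights.dim() == 2, \
        "`topk_weights` should be in shape (num_tokens, topk)"
    _, topk = topk_weights.shape
    assert topk == 1, \
        "Only support topk=1 when `apply_router_weight_on_input` is True"

    # Scale each token's activations by its router weight ...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    # ... then neutralize the weights so the MoE combine step
    # multiplies expert outputs by 1.0 instead of re-applying them.
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return hidden_states, topk_weights, topk_ids

# Toy usage: two bf16 tokens, one routed expert each (topk == 1).
x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.tensor([[0.25], [0.50]])
ids = torch.tensor([[3], [1]])
x2, w2, ids2 = fold_router_weight_into_input(x, w, ids)
assert torch.all(w2 == 1.0) and ids2.dtype == torch.int32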
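The FP8 block-scaled branch guarded by VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE relies on per_token_group_quant_fp8 imported from fp8_utils. As a rough orientation only (this is a generic sketch of per-token-group FP8 quantization, not vLLM's implementation, and the function name below is made up), such a quantizer gives each group of channels within a token row its own scale so that the group fits the FP8 E4M3 range:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3

def per_token_group_quant_sketch(x: torch.Tensor, group_size: int):
    num_tokens, hidden = x.shape
    assert hidden % group_size == 0
    groups = x.view(num_tokens, hidden // group_size, group_size).float()
    # One scale per (token, group), chosen so the group's max-abs value
    # maps onto the top of the FP8 range.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    x_q = (groups / scales).to(torch.float8_e4m3fn).view(num_tokens, hidden)
    return x_q, scales.squeeze(-1)

# Round-trip check: dequantizing approximately recovers the input.
x = torch.randn(4, 128)
x_q, s = per_token_group_quant_sketch(x, group_size=64)
x_hat = (x_q.float().view(4, 2, 64) * s.unsqueeze(-1)).view(4, 128)
assert (x - x_hat).abs().max() < 0.5  # small FP8 quantization error

The w1_scale / w2_scale asserts in the hunk then make sense: block-scaled FP8 weights are unusable without their per-group scale tensors, so the kernel refuses to run if they are missing.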