[FEAT][ROCm] Integrate Fused MoE Kernels from AITER (#14967)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
vllmellm 2025-03-26 16:30:30 +08:00 committed by GitHub
parent 781d056280
commit 5ebf66748b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 391 additions and 66 deletions

View File

@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
import pytest import pytest
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode() @torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool): def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
monkeypatch):
"""Make sure our Mixtral MoE implementation agrees with the one from """Make sure our Mixtral MoE implementation agrees with the one from
huggingface.""" huggingface."""
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Instantiate our and huggingface's MoE blocks # Instantiate our and huggingface's MoE blocks
config = MixtralConfig() config = MixtralConfig()
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
torch.bfloat16: 1e-2, torch.bfloat16: 1e-2,
} }
torch.testing.assert_close(hf_states.flatten(0, 1), if use_rocm_aiter:
vllm_states, # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
rtol=mixtral_moe_tol[dtype], # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
atol=mixtral_moe_tol[dtype]) torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=0.01,
atol=100)
else:
torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [1, 33, 64, 222])

View File

@ -7,6 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul, from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul) SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax)
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax)
assert topk_func == rocm_aiter_topk_softmax
else:
assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("inplace", [True, False])
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts)
assert fused_experts_func == rocm_aiter_fused_experts
elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts
else:
assert fused_experts_func == torch_vllm_outplace_fused_experts
@pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])

View File

@ -174,15 +174,8 @@ SAMPLE_JSON_SCHEMA = {
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_models( def test_models(hf_runner, vllm_runner, example_prompts, model: str,
hf_runner, dtype: str, max_tokens: int, num_logprobs: int) -> None:
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
# TODO(sang): Sliding window should be tested separately. # TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
@ -206,14 +199,8 @@ def test_models(
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format( def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
vllm_runner, max_tokens: int, num_logprobs: int) -> None:
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
@ -244,11 +231,8 @@ def test_mistral_format(
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages( def test_mistral_symbolic_languages(vllm_runner, model: str,
vllm_runner, dtype: str) -> None:
model: str,
dtype: str,
) -> None:
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
max_model_len=8192, max_model_len=8192,
@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", @pytest.mark.parametrize("model",
MISTRAL_FORMAT_MODELS) # v1 can't do func calling MISTRAL_FORMAT_MODELS) # v1 can't do func calling
def test_mistral_function_calling( def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
tokenizer_mode="mistral", tokenizer_mode="mistral",
@ -301,11 +281,8 @@ def test_mistral_function_calling(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend", @pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"]) ["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding( def test_mistral_guided_decoding(vllm_runner, model: str,
vllm_runner, guided_backend: str) -> None:
model: str,
guided_backend: str,
) -> None:
with vllm_runner(model, dtype='bfloat16', with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model: tokenizer_mode="mistral") as vllm_model:

View File

@ -23,8 +23,14 @@ MODELS = [
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None: use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if force_marlin: if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
@ -47,7 +53,13 @@ KV_CACHE_MODELS = [
@pytest.mark.skipif(not is_quant_method_supported("fp8"), @pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS) @pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch): @pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
use_rocm_aiter: bool, monkeypatch):
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# vllm_runner.apply_model() relies on V0 internals. # vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0") monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@ -86,8 +98,13 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None: use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# vllm_runner.apply_model() relies on V0 internals. # vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0") monkeypatch.setenv("VLLM_USE_V1", "0")

View File

@ -73,6 +73,8 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True
@ -513,6 +515,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")), ("true", "1")),
# Whether to use aiter moe ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
("true", "1")),
# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
("true", "1")),
# use aiter rms norm op if aiter ops are enabled. # use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM": "VLLM_ROCM_USE_AITER_RMSNORM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in

View File

@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
rocm_aiter_fused_experts,
rocm_aiter_topk_softmax)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
return config return config
def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_topk_softmax
return vllm_topk_softmax
def fused_topk( def fused_topk(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
@ -1059,17 +1085,14 @@ def fused_topk(
dtype=torch.int32, dtype=torch.int32,
device=hidden_states.device) device=hidden_states.device)
ops.topk_softmax( gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_weights,
topk_ids, topk_func = dispatch_topk_func()
token_expert_indicies, topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
gating_output.float(), # TODO(woosuk): Optimize this. token_expert_indicies,
) gating_output_float, renormalize)
del token_expert_indicies # Not used. Will be used in the future. del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids return topk_weights, topk_ids
@ -1259,6 +1282,24 @@ direct_register_custom_op(
) )
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
torch.ops.vllm.inplace_fused_experts(**kwargs)
hidden_states = kwargs['hidden_states']
return hidden_states
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
return torch.ops.vllm.outplace_fused_experts(**kwargs)
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
def fused_experts(hidden_states: torch.Tensor, def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None) -> torch.Tensor:
return dispatch_fused_experts_func(inplace)(
if inplace: hidden_states=hidden_states,
torch.ops.vllm.inplace_fused_experts( w1=w1,
hidden_states, w1, w2, topk_weights, topk_ids, activation, w2=w2,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, topk_weights=topk_weights,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, topk_ids=topk_ids,
block_shape) activation=activation,
return hidden_states use_fp8_w8a8=use_fp8_w8a8,
else: use_int8_w8a16=use_int8_w8a16,
return torch.ops.vllm.outplace_fused_experts( use_int4_w4a16=use_int4_w4a16,
hidden_states, w1, w2, topk_weights, topk_ids, activation, global_num_experts=global_num_experts,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map=expert_map,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale=w1_scale,
block_shape) w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,

View File

@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
from vllm import envs import vllm.envs as envs
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@ -16,6 +16,8 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
@ -118,6 +120,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w2_weight.data), layer.w2_weight.data),
requires_grad=False) requires_grad=False)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
if current_platform.is_cpu(): if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex

View File

@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
import torch
import vllm.envs as envs
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER \
def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
return is_rocm_aiter_moe_enabled() and \
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
def rocm_aiter_fused_experts(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
**kwagrs # Ignore additional keyword arguments
) -> torch.Tensor:
import aiter as rocm_aiter
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
assert w1_scale is not None
assert w2_scale is not None
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
dtype = hidden_states.dtype
# The default block sizes are 128 in AITER.
if block_shape is None:
block_shape = [128, 128]
scale_blk_k = block_shape[1]
(
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
out_asm,
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
dtype,
expert_mask=expert_mask)
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
rocm_aiter.fmoe_fp8_blockscale_g1u1(
out_asm,
a1,
w1,
w2,
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
topk,
w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(),
block_shape[0],
block_shape[1],
None,
)
return out_asm
elif use_fp8_w8a8:
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False)
return rocm_aiter.ck_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids)
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
import aiter as rocm_aiter
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
gating_output, renormalize)
return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""
Applies shuffle_weight function from AITER to each
input tensor and returns them.
Args:
*tensors: Variable number of torch.Tensor objects.
Returns:
A tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor) for tensor in tensors)
def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.
Args:
*tensors: A variable number of torch.Tensor objects.
expansion_dims: A list of expansion dimensions
corresponding to each tensor.
Returns:
A tuple of tensors with expanded dimensions.
"""
assert len(tensors) == len(expansion_dims), \
"Number of tensors must match the number of expansion dimensions."
return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))

View File

@ -13,6 +13,9 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
is_rocm_aiter_moe_enabled, shuffle_weights)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -554,6 +557,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False) requires_grad=False)
if is_rocm_aiter_block_scaled_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
return return
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
@ -581,6 +593,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False) requires_grad=False)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
w13_scales, w2_scales = expand_weights(
layer.w13_weight_scale.data,
layer.w2_weight_scale.data,
expansion_dims=[
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
])
layer.w13_weight_scale = torch.nn.Parameter(
w13_scales.contiguous(), requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight, layer.w2_weight)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
return return
# If checkpoint is fp8, we need to handle that the # If checkpoint is fp8, we need to handle that the
@ -648,6 +680,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dq_weight, max_w13_scales[expert_id]) dq_weight, max_w13_scales[expert_id])
start += shard_size start += shard_size
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
expansion_dims = [
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
]
max_w13_scales, w2_scales = expand_weights(
max_w13_scales,
layer.w2_weight_scale.data,
expansion_dims=expansion_dims)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales.contiguous(), requires_grad=False)
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight, layer.w2_weight)
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False) requires_grad=False)
return return