diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 653d2734..3f4dd3cf 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -3,7 +3,6 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
-
 import pytest
 import torch
 from torch.nn import Parameter
@@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype, padding: bool):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
+                     monkeypatch):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
         torch.bfloat16: 1e-2,
     }
 
-    torch.testing.assert_close(hf_states.flatten(0, 1),
-                               vllm_states,
-                               rtol=mixtral_moe_tol[dtype],
-                               atol=mixtral_moe_tol[dtype])
+    if use_rocm_aiter:
+        # The values of rtol and atol are set based on the tests in ROCm AITER package. # noqa: E501
+        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=0.01,
+                                   atol=100)
+    else:
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=mixtral_moe_tol[dtype],
+                                   atol=mixtral_moe_tol[dtype])
 
 
 @pytest.mark.parametrize("m", [1, 33, 64, 222])
diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 24147b74..ac2e0f35 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -7,6 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    dispatch_fused_experts_func, dispatch_topk_func,
+    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
+    vllm_topk_softmax)
 from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()
 
 
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    topk_func = dispatch_topk_func()
+
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_topk_softmax)
+
+        assert topk_func == rocm_aiter_topk_softmax
+    else:
+        assert topk_func == vllm_topk_softmax
+
+
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("inplace", [True, False])
+def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
+                                monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    fused_experts_func = dispatch_fused_experts_func(inplace)
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)
+
+        assert fused_experts_func == rocm_aiter_fused_experts
+    elif inplace:
+        assert fused_experts_func == torch_vllm_inplace_fused_experts
+    else:
+        assert fused_experts_func == torch_vllm_outplace_fused_experts
+
+
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index 4c205536..ec885386 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -174,15 +174,8 @@ SAMPLE_JSON_SCHEMA = {
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_models(hf_runner, vllm_runner, example_prompts, model: str,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
     # TODO(sang): Sliding window should be tested separately.
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +199,8 @@ def test_models(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_mistral_format(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
+                        max_tokens: int, num_logprobs: int) -> None:
     with vllm_runner(
             model,
             dtype=dtype,
@@ -244,11 +231,8 @@ def test_mistral_format(
 
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-def test_mistral_symbolic_languages(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_symbolic_languages(vllm_runner, model: str,
+                                    dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=8192,
@@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 # v1 can't do func calling
-def test_mistral_function_calling(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      tokenizer_mode="mistral",
@@ -301,11 +281,8 @@ def test_mistral_function_calling(
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(
-    vllm_runner,
-    model: str,
-    guided_backend: str,
-) -> None:
+def test_mistral_guided_decoding(vllm_runner, model: str,
+                                 guided_backend: str) -> None:
     with vllm_runner(model, dtype='bfloat16',
                      tokenizer_mode="mistral") as vllm_model:
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index 19cf29d3..e74e14a0 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -23,8 +23,14 @@ MODELS = [
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", MODELS)
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
-                            monkeypatch) -> None:
+                            use_rocm_aiter: bool, monkeypatch) -> None:
+
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
 
@@ -47,7 +53,13 @@ KV_CACHE_MODELS = [
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
-def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
+                                     use_rocm_aiter: bool, monkeypatch):
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +98,13 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
-                         monkeypatch) -> None:
+                         use_rocm_aiter: bool, monkeypatch) -> None:
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")
 
diff --git a/vllm/envs.py b/vllm/envs.py
index b4305d9c..4c413006 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -73,6 +73,8 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_MOE: bool = True
+    VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -513,6 +515,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),
 
+    # Whether to use aiter moe ops.
+    # By default this is enabled.
+    "VLLM_ROCM_USE_AITER_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
+             ("true", "1")),
+
+    # Whether to use aiter block scaled moe kernel.
+    # By default this is disabled.
+    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
+    lambda:
+    (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
+     ("true", "1")),
+
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 4de020ff..97e915c6 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
+                                   rocm_aiter_fused_experts,
+                                   rocm_aiter_topk_softmax)
+
 logger = init_logger(__name__)
 
 
@@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
     return config
 
 
+def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
+                      token_expert_indices: torch.Tensor,
+                      gating_output: torch.Tensor,
+                      renormalize: bool) -> tuple[torch.Tensor, ...]:
+    ops.topk_softmax(
+        topk_weights,
+        topk_indices,
+        token_expert_indices,
+        gating_output,
+    )
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_indices
+
+
+def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_topk_softmax
+    return vllm_topk_softmax
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -1059,17 +1085,14 @@ def fused_topk(
                                         dtype=torch.int32,
                                         device=hidden_states.device)
 
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
+    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.
+
+    topk_func = dispatch_topk_func()
+    topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
+                                       token_expert_indicies,
+                                       gating_output_float, renormalize)
+
     del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
     return topk_weights, topk_ids
 
 
@@ -1259,6 +1282,24 @@ direct_register_custom_op(
 )
 
 
+def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
+    torch.ops.vllm.inplace_fused_experts(**kwargs)
+    hidden_states = kwargs['hidden_states']
+    return hidden_states
+
+
+def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
+    return torch.ops.vllm.outplace_fused_experts(**kwargs)
+
+
+def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_fused_experts
+    if inplace:
+        return torch_vllm_inplace_fused_experts
+    return torch_vllm_outplace_fused_experts
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
@@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[List[int]] = None) -> torch.Tensor:
-
-    if inplace:
-        torch.ops.vllm.inplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
-        return hidden_states
-    else:
-        return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
+    return dispatch_fused_experts_func(inplace)(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape)
 
 
 def fused_experts_impl(hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index bc134f67..b72f51aa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
-from vllm import envs
+import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -16,6 +16,8 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -118,6 +120,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                                                layer.w2_weight.data),
                                            requires_grad=False)
 
+        if is_rocm_aiter_moe_enabled():
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 import intel_extension_for_pytorch as ipex
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
new file mode 100644
index 00000000..c9bb6767
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.platforms import current_platform
+
+
+def is_rocm_aiter_moe_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_MOE \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
+    return is_rocm_aiter_moe_enabled() and \
+        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+
+
+def rocm_aiter_fused_experts(
+        *,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+        **kwargs  # Ignore additional keyword arguments
+) -> torch.Tensor:
+
+    import aiter as rocm_aiter
+    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
+
+    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
+        assert w1_scale is not None
+        assert w2_scale is not None
+
+        local_E = E = w1.shape[0]
+        if expert_mask is not None:
+            E = expert_mask.numel()
+
+        topk = topk_ids.shape[1]
+        model_dim = w1.shape[-1]
+        dtype = hidden_states.dtype
+        # The default block sizes are 128 in AITER.
+        if block_shape is None:
+            block_shape = [128, 128]
+
+        scale_blk_k = block_shape[1]
+
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
+                                               topk_weights,
+                                               E,
+                                               model_dim,
+                                               dtype,
+                                               expert_mask=expert_mask)
+
+        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
+        rocm_aiter.fmoe_fp8_blockscale_g1u1(
+            out_asm,
+            a1,
+            w1,
+            w2,
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            topk,
+            w1_scale.view(local_E, -1),
+            w2_scale.view(local_E, -1),
+            a1_scale.t().contiguous(),
+            block_shape[0],
+            block_shape[1],
+            None,
+        )
+        return out_asm
+
+    elif use_fp8_w8a8:
+        return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
+                                           w1=w1,
+                                           w2=w2,
+                                           topk_weight=topk_weights,
+                                           topk_ids=topk_ids,
+                                           fc1_scale=w1_scale,
+                                           fc2_scale=w2_scale,
+                                           fc1_smooth_scale=None,
+                                           fc2_smooth_scale=None,
+                                           a16=False)
+
+    return rocm_aiter.ck_moe(hidden_states=hidden_states,
+                             w1=w1,
+                             w2=w2,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids)
+
+
+def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
+                            topk_indices: torch.Tensor,
+                            token_expert_indices: torch.Tensor,
+                            gating_output: torch.Tensor,
+                            renormalize: bool) -> tuple[torch.Tensor, ...]:
+    import aiter as rocm_aiter
+    rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
+                            gating_output, renormalize)
+
+    return topk_weights, topk_indices
+
+
+def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+    """
+    Applies shuffle_weight function from AITER to each
+    input tensor and returns them.
+
+    Args:
+        *tensors: Variable number of torch.Tensor objects.
+
+    Returns:
+        A tuple of shuffled tensors.
+    """
+    from aiter.ops.shuffle import shuffle_weight
+
+    return tuple(shuffle_weight(tensor) for tensor in tensors)
+
+
+def expand_weights(*tensors: torch.Tensor,
+                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
+    """
+    Expands the dimensions of input tensors.
+
+    Args:
+        *tensors: A variable number of torch.Tensor objects.
+        expansion_dims: A list of expansion dimensions
+            corresponding to each tensor.
+
+    Returns:
+        A tuple of tensors with expanded dimensions.
+    """
+
+    assert len(tensors) == len(expansion_dims), \
+        "Number of tensors must match the number of expansion dimensions."
+
+    return tuple(
+        tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
+        for tensor, dim in zip(tensors, expansion_dims))
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index d92b0931..bc17a569 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -13,6 +13,9 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
+    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -554,6 +557,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
+            if is_rocm_aiter_block_scaled_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight.data, layer.w2_weight.data)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
             return
 
         # If checkpoint is fp16, quantize in place.
@@ -581,6 +593,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                    requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
+            if is_rocm_aiter_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                w13_scales, w2_scales = expand_weights(
+                    layer.w13_weight_scale.data,
+                    layer.w2_weight_scale.data,
+                    expansion_dims=[
+                        layer.w13_weight.shape[1], layer.w2_weight.shape[1]
+                    ])
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_scales.contiguous(), requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
+
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -648,6 +680,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                             dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
+            if is_rocm_aiter_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                expansion_dims = [
+                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
+                ]
+                max_w13_scales, w2_scales = expand_weights(
+                    max_w13_scales,
+                    layer.w2_weight_scale.data,
+                    expansion_dims=expansion_dims)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
+
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
+
             layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                         requires_grad=False)
             return
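
Note on the dispatch pattern added by this patch (an illustrative sketch, not part of the diff): kernel selection is driven by environment flags that vllm.envs evaluates lazily. is_rocm_aiter_moe_enabled() is true only on a ROCm platform with VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MOE left at its default of enabled, and dispatch_topk_func() / dispatch_fused_experts_func(inplace) then return either the AITER kernels or the stock vLLM implementations. The sketch below assumes a vLLM build that already contains this patch; on a non-ROCm build it always reports the vLLM defaults, which is what the new test_topk_dispatch and test_fused_experts_dispatch tests assert.

Example (Python):

import os

# The flags must be set before the dispatchers are queried; both are read
# lazily through vllm.envs, and VLLM_ROCM_USE_AITER_MOE defaults to enabled.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_fused_experts_func, dispatch_topk_func)
from vllm.platforms import current_platform

topk_func = dispatch_topk_func()
experts_func = dispatch_fused_experts_func(inplace=True)

if current_platform.is_rocm():
    # Expected: rocm_aiter_topk_softmax and rocm_aiter_fused_experts
    print(topk_func.__name__, experts_func.__name__)
else:
    # Expected: vllm_topk_softmax and torch_vllm_inplace_fused_experts
    print(topk_func.__name__, experts_func.__name__)

The FP8 block-scaled path follows the same convention through VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE, which is disabled by default.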