[ROCm][Kernel] MoE weights padding (#14454)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Gregory Shtrasberg 2025-03-24 19:45:30 -04:00 committed by GitHub
parent 8279201ce6
commit f533b5837f
5 changed files with 65 additions and 16 deletions

View File

@@ -3,8 +3,11 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
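The pad-then-slice idiom in the hunk above may look like a no-op: it allocates storage with 128 extra trailing elements per row and then views back to the original logical shape, so the values are unchanged and torch_output, computed on the unpadded weights, remains a valid reference. A minimal standalone sketch of that property (the tensor `w` and its shape are stand-ins, not the test's w1):

```python
import torch
import torch.nn.functional as F

# Stand-in weight with the same layout idea as w1: (experts, rows, k).
w = torch.randn((8, 64, 128), dtype=torch.float16)

# Pad the last dim by 128 elements, then slice the padding back off.
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

# Same values and logical shape, so any reference output stays comparable ...
assert w_padded.shape == w.shape
torch.testing.assert_close(w_padded, w)

# ... but the row pitch of the underlying storage grew from 128 to 256 elements.
print(w.stride(-2), w_padded.stride(-2))  # 128 256
```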
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
 
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
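One detail worth noting in the Mixtral test: F.pad(...)[..., 0:-128] returns a plain tensor view rather than an nn.Parameter, so the padded weights are re-registered as fresh Parameters with requires_grad=False instead of being written into the existing parameter. A small sketch of that pattern on a stand-in module (TinyExperts and its shapes are illustrative, not vLLM code):

```python
import torch
from torch.nn import Parameter
from torch.nn import functional as F

class TinyExperts(torch.nn.Module):  # stand-in for vllm_moe.experts
    def __init__(self):
        super().__init__()
        self.w2_weight = Parameter(torch.randn(8, 64, 128), requires_grad=False)

experts = TinyExperts()
padded = F.pad(experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128]
print(type(padded))                  # plain torch.Tensor, not Parameter

# Re-wrap so the module still exposes an nn.Parameter backed by padded storage.
experts.w2_weight = Parameter(padded, requires_grad=False)
print(experts.w2_weight.stride(-2))  # 256
```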

View File

@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
 
+    # Pad the weights for the moe kernel
+    "VLLM_ROCM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
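The new flag defaults to on and is read through os.getenv, so it can be disabled from the environment before vLLM starts, for example by prefixing the launch command with VLLM_ROCM_MOE_PADDING=0. An equivalent Python sketch (purely illustrative; set the variable before vLLM reads it):

```python
import os

# Turn the MoE weight padding off. The flag defaults to "1" (enabled) and,
# per the gating in the layer code, is only consulted on ROCm.
os.environ["VLLM_ROCM_MOE_PADDING"] = "0"
```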

View File

@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        A.shape[1],
+        B.shape[2],
         EM,
         topk_ids.numel(),
         A.stride(0),
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
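The relaxed asserts are what make padded expert weights legal inputs: after pad-then-slice a weight is no longer contiguous, but its last-dimension stride is still 1, and the kernel launch passes strides explicitly (note the A.stride(0) argument in the call above). A small sketch of the distinction, with an illustrative shape and a 2-byte dtype:

```python
import torch
import torch.nn.functional as F

w2 = torch.randn((8, 128, 64), dtype=torch.bfloat16)
w2_padded = F.pad(w2, (0, 128), "constant", 0)[..., :-128]

# The old assert would have rejected the padded tensor ...
print(w2_padded.is_contiguous())   # False

# ... while the new, weaker condition still holds.
print(w2_padded.stride(-1) == 1)   # True
```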

View File

@@ -5,6 +5,7 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 from vllm import envs
@@ -96,9 +97,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
+        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w13_weight.data),
+                                              requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w2_weight.data),
+                                             requires_grad=False)
+
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 import intel_extension_for_pytorch as ipex
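For reference, the pad width is fixed at 256 bytes, so the number of padded elements depends on the dtype, and the guard only fires for weights whose last-dim stride is 1 and whose row pitch in bytes is already a multiple of 512. A standalone sketch of that arithmetic (pad_like_vllm and the shapes are hypothetical stand-ins that mirror, not import, the helper above, without the env/platform gate):

```python
import torch
import torch.nn.functional as F

def pad_like_vllm(weight: torch.Tensor) -> torch.Tensor:
    """Sketch of the 256-byte MoE weight padding, minus the ROCm/env checks."""
    if (weight.stride(-1) == 1
            and (weight.stride(-2) * weight.element_size()) % 512 == 0):
        # 256 bytes -> 128 elements for bf16/fp16, 64 for fp32.
        num_pad = 256 // weight.element_size()
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    return weight

w13 = torch.randn((8, 256, 4096), dtype=torch.bfloat16)   # row pitch 8192 B: padded
print(pad_like_vllm(w13).stride(-2))    # 4096 + 128 = 4224

w_odd = torch.randn((8, 256, 100), dtype=torch.bfloat16)  # row pitch 200 B: left alone
print(pad_like_vllm(w_odd).stride(-2))  # 100
```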

View File

@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
         else:
             layer.register_parameter("input_scale", None)
 
-    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
         # Pad the weight tensor. This is an optimization on ROCm platform, which
         # can benefit from tensors located far enough from one another in memory
         if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
 
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 logical_widths=layer.logical_widths,
             )
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)