[ROCm][Kernel] MoE weights padding (#14454)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
commit f533b5837f (parent 8279201ce6)
@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe(
    m: int,
    n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
):
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10

@@ -65,16 +70,7 @@ def test_fused_moe(
    else:
        e_map = None

    triton_output = fused_moe(a,
                              w1,
                              w2,
                              score,
                              topk,
                              global_num_experts=e,
                              expert_map=e_map,
                              renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk, e_map)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    iterative_output = iterative_moe(a,
                                     w1,
                                     w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                     global_num_experts=e,
                                     expert_map=e_map,
                                     renormalize=False)

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        torch.cuda.empty_cache()
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
        torch.cuda.empty_cache()

    triton_output = fused_moe(a,
                              w1,
                              w2,
                              score,
                              topk,
                              global_num_experts=e,
                              expert_map=e_map,
                              renormalize=False)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,

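The pad-then-slice trick above changes only the memory layout of the expert weights, not their values, so the test can compare the padded run against the same reference outputs. A minimal standalone sketch of the trick (sizes here are arbitrary, not taken from the test):

import torch
import torch.nn.functional as F

w = torch.randn(4, 32, 64, dtype=torch.float16)
# Append 128 zero elements to the last dim, then slice them off again.
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

assert torch.equal(w_padded, w)     # identical values -> identical MoE results
assert w_padded.shape == w.shape    # identical logical shape
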
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,

@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
def test_mixtral_moe(dtype: torch.dtype, padding: bool):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
    # vLLM uses 1D query [num_tokens, hidden_dim]
    vllm_inputs = hf_inputs.flatten(0, 1)

    # Pad the weight if moe padding is enabled
    if padding:
        vllm_moe.experts.w13_weight = Parameter(F.pad(
            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
                                                requires_grad=False)
        torch.cuda.empty_cache()
        vllm_moe.experts.w2_weight = Parameter(F.pad(
            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                                               requires_grad=False)
        torch.cuda.empty_cache()

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)

@@ -75,6 +75,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_AITER: bool = False
    VLLM_ROCM_USE_AITER_RMSNORM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_ROCM_FP8_PADDING":
    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),

    # Pad the weights for the moe kernel
    "VLLM_ROCM_MOE_PADDING":
    lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),

    # Divisor for dynamic query scale factor calculation for FP8 KV Cache
    "Q_SCALE_CONSTANT":
    lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),

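Since the new flag defaults to enabled, opting out is a matter of setting the environment variable before it is read. A hedged usage sketch (the model name is only an example, not part of this change):

import os

# Turn off ROCm MoE weight padding for this process; the default is "1" (enabled).
os.environ["VLLM_ROCM_MOE_PADDING"] = "0"

from vllm import LLM

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
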
@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        A.shape[1],
        B.shape[2],
        EM,
        topk_ids.numel(),
        A.stride(0),
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

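The relaxed asserts account for padded weights: a padded-then-sliced tensor is no longer contiguous, but its innermost dimension still has unit stride, which is what the kernel arguments above actually rely on. A small illustration (shapes are arbitrary):

import torch
import torch.nn.functional as F

w = torch.randn(4, 64, 128, dtype=torch.float16)
w_pad = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

print(w.stride(), w.is_contiguous())          # (8192, 128, 1) True
print(w_pad.stride(), w_pad.is_contiguous())  # (16384, 256, 1) False
print(w_pad.stride(-1) == 1)                  # True: the new assert still passes
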
@@ -5,6 +5,7 @@ from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

from vllm import envs
@@ -96,9 +97,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w13_weight.data),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
            layer.w2_weight.data),
                                             requires_grad=False)

        if current_platform.is_cpu():
            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                import intel_extension_for_pytorch as ipex

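For the dtypes this path handles, the guard in `_maybe_pad_weight` works out to 256 bytes of padding per row, applied only when rows are already 512-byte aligned and the last dimension has unit stride. A worked sketch of the arithmetic (the tensor is a stand-in for an expert weight, not the real layer state):

import torch
import torch.nn.functional as F

weight = torch.empty(8, 1024, 4096, dtype=torch.bfloat16)  # element_size() == 2

num_pad = 256 // weight.element_size()                  # 128 elements == 256 bytes
row_bytes = weight.stride(-2) * weight.element_size()   # 4096 * 2 == 8192 bytes per row

if weight.stride(-1) == 1 and row_bytes % 512 == 0:     # mirrors the guard above
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(weight.stride())  # (4325376, 4224, 1): rows now sit 256 bytes further apart
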
@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
        else:
            layer.register_parameter("input_scale", None)

    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on ROCm platform, which
        # can benefit from tensors located far enough from one another in memory
        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
            weight = layer.weight.data
            weight_scale_inv = layer.weight_scale_inv.data

            weight = self.add_padding_to_weight(weight)
            weight = self._maybe_pad_weight(weight)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
                logical_widths=layer.logical_widths,
            )

            weight = self.add_padding_to_weight(weight)
            weight = self._maybe_pad_weight(weight)
            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)