[FEAT][ROCm] Integrate Fused MoE Kernels from AITER (#14967)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
781d056280
commit
5ebf66748b
@ -3,7 +3,6 @@
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool):
|
||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
monkeypatch):
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
config = MixtralConfig()
|
||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||
@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
|
||||
torch.bfloat16: 1e-2,
|
||||
}
|
||||
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
if use_rocm_aiter:
|
||||
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
|
||||
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=0.01,
|
||||
atol=100)
|
||||
else:
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
|
@ -7,6 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
dispatch_fused_experts_func, dispatch_topk_func,
|
||||
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||
@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
topk_func = dispatch_topk_func()
|
||||
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_topk_softmax)
|
||||
|
||||
assert topk_func == rocm_aiter_topk_softmax
|
||||
else:
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
fused_experts_func = dispatch_fused_experts_func(inplace)
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts)
|
||||
|
||||
assert fused_experts_func == rocm_aiter_fused_experts
|
||||
elif inplace:
|
||||
assert fused_experts_func == torch_vllm_inplace_fused_experts
|
||||
else:
|
||||
assert fused_experts_func == torch_vllm_outplace_fused_experts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
|
@ -174,15 +174,8 @@ SAMPLE_JSON_SCHEMA = {
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
# TODO(sang): Sliding window should be tested separately.
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
@ -206,14 +199,8 @@ def test_models(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_mistral_format(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
|
||||
max_tokens: int, num_logprobs: int) -> None:
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
@ -244,11 +231,8 @@ def test_mistral_format(
|
||||
|
||||
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
def test_mistral_symbolic_languages(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
def test_mistral_symbolic_languages(vllm_runner, model: str,
|
||||
dtype: str) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
max_model_len=8192,
|
||||
@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("model",
|
||||
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
|
||||
def test_mistral_function_calling(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tokenizer_mode="mistral",
|
||||
@ -301,11 +281,8 @@ def test_mistral_function_calling(
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("guided_backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
def test_mistral_guided_decoding(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
guided_backend: str,
|
||||
) -> None:
|
||||
def test_mistral_guided_decoding(vllm_runner, model: str,
|
||||
guided_backend: str) -> None:
|
||||
with vllm_runner(model, dtype='bfloat16',
|
||||
tokenizer_mode="mistral") as vllm_model:
|
||||
|
||||
|
@ -23,8 +23,14 @@ MODELS = [
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_id", MODELS)
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
|
||||
monkeypatch) -> None:
|
||||
use_rocm_aiter: bool, monkeypatch) -> None:
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
if force_marlin:
|
||||
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
|
||||
|
||||
@ -47,7 +53,13 @@ KV_CACHE_MODELS = [
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
|
||||
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
|
||||
use_rocm_aiter: bool, monkeypatch):
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
||||
@ -86,8 +98,13 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
|
||||
monkeypatch) -> None:
|
||||
use_rocm_aiter: bool, monkeypatch) -> None:
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
|
15
vllm/envs.py
15
vllm/envs.py
@ -73,6 +73,8 @@ if TYPE_CHECKING:
|
||||
VLLM_DISABLED_KERNELS: list[str] = []
|
||||
VLLM_USE_V1: bool = True
|
||||
VLLM_ROCM_USE_AITER: bool = False
|
||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
@ -513,6 +515,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Whether to use aiter moe ops.
|
||||
# By default is enabled.
|
||||
"VLLM_ROCM_USE_AITER_MOE":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Whether to use aiter block scaled moe kernel.
|
||||
# By default this is disabled.
|
||||
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
|
||||
lambda:
|
||||
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# use aiter rms norm op if aiter ops are enabled.
|
||||
"VLLM_ROCM_USE_AITER_RMSNORM":
|
||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
|
||||
|
@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
rocm_aiter_fused_experts,
|
||||
rocm_aiter_topk_softmax)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
|
||||
return config
|
||||
|
||||
|
||||
def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool) -> tuple[torch.Tensor, ...]:
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
return rocm_aiter_topk_softmax
|
||||
return vllm_topk_softmax
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@ -1059,17 +1085,14 @@ def fused_topk(
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
||||
)
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output_float, renormalize)
|
||||
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@ -1259,6 +1282,24 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
torch.ops.vllm.inplace_fused_experts(**kwargs)
|
||||
hidden_states = kwargs['hidden_states']
|
||||
return hidden_states
|
||||
|
||||
|
||||
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
return torch.ops.vllm.outplace_fused_experts(**kwargs)
|
||||
|
||||
|
||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
return rocm_aiter_fused_experts
|
||||
if inplace:
|
||||
return torch_vllm_inplace_fused_experts
|
||||
return torch_vllm_outplace_fused_experts
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
if inplace:
|
||||
torch.ops.vllm.inplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, activation,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
return hidden_states
|
||||
else:
|
||||
return torch.ops.vllm.outplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, activation,
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
|
||||
block_shape)
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
|
@ -8,7 +8,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
from vllm import envs
|
||||
import vllm.envs as envs
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -16,6 +16,8 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -118,6 +120,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
if current_platform.is_cpu():
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
157
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
Normal file
157
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
Normal file
@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def is_rocm_aiter_moe_enabled() -> bool:
|
||||
return current_platform.is_rocm() \
|
||||
and envs.VLLM_ROCM_USE_AITER_MOE \
|
||||
and envs.VLLM_ROCM_USE_AITER \
|
||||
|
||||
|
||||
def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
|
||||
return is_rocm_aiter_moe_enabled() and \
|
||||
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
|
||||
|
||||
|
||||
def rocm_aiter_fused_experts(
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
expert_mask: Optional[torch.Tensor] = None,
|
||||
**kwagrs # Ignore additional keyword arguments
|
||||
) -> torch.Tensor:
|
||||
|
||||
import aiter as rocm_aiter
|
||||
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
|
||||
|
||||
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
local_E = E = w1.shape[0]
|
||||
if expert_mask is not None:
|
||||
E = expert_mask.numel()
|
||||
|
||||
topk = topk_ids.shape[1]
|
||||
model_dim = w1.shape[-1]
|
||||
dtype = hidden_states.dtype
|
||||
# The default block sizes are 128 in AITER.
|
||||
if block_shape is None:
|
||||
block_shape = [128, 128]
|
||||
|
||||
scale_blk_k = block_shape[1]
|
||||
|
||||
(
|
||||
sorted_token_ids,
|
||||
sorted_weight_buf,
|
||||
sorted_expert_ids,
|
||||
num_valid_ids,
|
||||
out_asm,
|
||||
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
|
||||
topk_weights,
|
||||
E,
|
||||
model_dim,
|
||||
dtype,
|
||||
expert_mask=expert_mask)
|
||||
|
||||
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
|
||||
rocm_aiter.fmoe_fp8_blockscale_g1u1(
|
||||
out_asm,
|
||||
a1,
|
||||
w1,
|
||||
w2,
|
||||
sorted_token_ids,
|
||||
sorted_weight_buf,
|
||||
sorted_expert_ids,
|
||||
num_valid_ids,
|
||||
topk,
|
||||
w1_scale.view(local_E, -1),
|
||||
w2_scale.view(local_E, -1),
|
||||
a1_scale.t().contiguous(),
|
||||
block_shape[0],
|
||||
block_shape[1],
|
||||
None,
|
||||
)
|
||||
return out_asm
|
||||
|
||||
elif use_fp8_w8a8:
|
||||
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weight=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
fc1_scale=w1_scale,
|
||||
fc2_scale=w2_scale,
|
||||
fc1_smooth_scale=None,
|
||||
fc2_smooth_scale=None,
|
||||
a16=False)
|
||||
|
||||
return rocm_aiter.ck_moe(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool) -> tuple[torch.Tensor, ...]:
|
||||
import aiter as rocm_aiter
|
||||
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
|
||||
gating_output, renormalize)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Applies shuffle_weight function from AITER to each
|
||||
input tensor and returns them.
|
||||
|
||||
Args:
|
||||
*tensors: Variable number of torch.Tensor objects.
|
||||
|
||||
Returns:
|
||||
A tuple of shuffled tensors.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
return tuple(shuffle_weight(tensor) for tensor in tensors)
|
||||
|
||||
|
||||
def expand_weights(*tensors: torch.Tensor,
|
||||
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Expands the dimensions of input tensors.
|
||||
|
||||
Args:
|
||||
*tensors: A variable number of torch.Tensor objects.
|
||||
expansion_dims: A list of expansion dimensions
|
||||
corresponding to each tensor.
|
||||
|
||||
Returns:
|
||||
A tuple of tensors with expanded dimensions.
|
||||
"""
|
||||
|
||||
assert len(tensors) == len(expansion_dims), \
|
||||
"Number of tensors must match the number of expansion dimensions."
|
||||
|
||||
return tuple(
|
||||
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
|
||||
for tensor, dim in zip(tensors, expansion_dims))
|
@ -13,6 +13,9 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -554,6 +557,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_block_scaled_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
@ -581,6 +593,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
w13_scales, w2_scales = expand_weights(
|
||||
layer.w13_weight_scale.data,
|
||||
layer.w2_weight_scale.data,
|
||||
expansion_dims=[
|
||||
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
|
||||
])
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_scales.contiguous(), requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales.contiguous(), requires_grad=False)
|
||||
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
@ -648,6 +680,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
expansion_dims = [
|
||||
layer.w13_weight.shape[1], layer.w2_weight.shape[1]
|
||||
]
|
||||
max_w13_scales, w2_scales = expand_weights(
|
||||
max_w13_scales,
|
||||
layer.w2_weight_scale.data,
|
||||
expansion_dims=expansion_dims)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_scales.contiguous(), requires_grad=False)
|
||||
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user