[FEAT][ROCm] Integrate Fused MoE Kernels from AITER (#14967)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>

parent 781d056280
commit 5ebf66748b

@@ -3,7 +3,6 @@
 Run `pytest tests/kernels/test_moe.py`.
 """

 import pytest
 import torch
 from torch.nn import Parameter
@@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype, padding: bool):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
+                     monkeypatch):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""

+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
@@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
         torch.bfloat16: 1e-2,
     }

-    torch.testing.assert_close(hf_states.flatten(0, 1),
-                               vllm_states,
-                               rtol=mixtral_moe_tol[dtype],
-                               atol=mixtral_moe_tol[dtype])
+    if use_rocm_aiter:
+        # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
+        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=0.01,
+                                   atol=100)
+    else:
+        torch.testing.assert_close(hf_states.flatten(0, 1),
+                                   vllm_states,
+                                   rtol=mixtral_moe_tol[dtype],
+                                   atol=mixtral_moe_tol[dtype])


 @pytest.mark.parametrize("m", [1, 33, 64, 222])
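As a side note on the tolerance change above: torch.testing.assert_close treats two elements as equal when |actual - expected| <= atol + rtol * |expected|, so the AITER branch accepts a 1% relative error plus a fairly large absolute slack. A minimal sketch of that rule (the tensor values here are made up purely for illustration):

import torch

expected = torch.tensor([1000.0, 2000.0])
actual = expected + torch.tensor([5.0, -15.0])  # small drift, as a quantized kernel might produce

rtol, atol = 0.01, 100
# Element-wise rule that torch.testing.assert_close applies.
assert torch.all((actual - expected).abs() <= atol + rtol * expected.abs())
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)  # passes for the same reason
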
@@ -7,6 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                     ReLUSquaredActivation,
                                                     SiluAndMul)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    dispatch_fused_experts_func, dispatch_topk_func,
+    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
+    vllm_topk_softmax)
 from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
@@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()


+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    topk_func = dispatch_topk_func()
+
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_topk_softmax)
+
+        assert topk_func == rocm_aiter_topk_softmax
+    else:
+        assert topk_func == vllm_topk_softmax
+
+
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("inplace", [True, False])
+def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
+                                monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    fused_experts_func = dispatch_fused_experts_func(inplace)
+    if current_platform.is_rocm() and int(use_rocm_aiter):
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)
+
+        assert fused_experts_func == rocm_aiter_fused_experts
+    elif inplace:
+        assert fused_experts_func == torch_vllm_inplace_fused_experts
+    else:
+        assert fused_experts_func == torch_vllm_outplace_fused_experts
+
+
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@@ -174,15 +174,8 @@ SAMPLE_JSON_SCHEMA = {
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    hf_runner,
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_models(hf_runner, vllm_runner, example_prompts, model: str,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
     # TODO(sang): Sliding window should be tested separately.
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +199,8 @@ def test_models(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-def test_mistral_format(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
+def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
+                        max_tokens: int, num_logprobs: int) -> None:
     with vllm_runner(
             model,
             dtype=dtype,
@@ -244,11 +231,8 @@ def test_mistral_format(

 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-def test_mistral_symbolic_languages(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_symbolic_languages(vllm_runner, model: str,
+                                    dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=8192,
@@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("model",
                          MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
-def test_mistral_function_calling(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
+def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
                      tokenizer_mode="mistral",
@@ -301,11 +281,8 @@ def test_mistral_function_calling(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(
-    vllm_runner,
-    model: str,
-    guided_backend: str,
-) -> None:
+def test_mistral_guided_decoding(vllm_runner, model: str,
+                                 guided_backend: str) -> None:
     with vllm_runner(model, dtype='bfloat16',
                      tokenizer_mode="mistral") as vllm_model:

@@ -23,8 +23,14 @@ MODELS = [
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", MODELS)
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
-                            monkeypatch) -> None:
+                            use_rocm_aiter: bool, monkeypatch) -> None:
+
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     if force_marlin:
         monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

@@ -47,7 +53,13 @@ KV_CACHE_MODELS = [
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
-def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
+                                     use_rocm_aiter: bool, monkeypatch):
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")
     with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +98,13 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
                     reason="FP8 is not supported on this GPU type.")
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
-                         monkeypatch) -> None:
+                         use_rocm_aiter: bool, monkeypatch) -> None:
+    if use_rocm_aiter:
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # vllm_runner.apply_model() relies on V0 internals.
     monkeypatch.setenv("VLLM_USE_V1", "0")

vllm/envs.py (15 changed lines)

@@ -73,6 +73,8 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_MOE: bool = True
+    VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
@@ -513,6 +515,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),

+    # Whether to use aiter moe ops.
+    # By default is enabled.
+    "VLLM_ROCM_USE_AITER_MOE":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
+             ("true", "1")),
+
+    # Whether to use aiter block scaled moe kernel.
+    # By default this is disabled.
+    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
+    lambda:
+    (os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
+     ("true", "1")),
+
     # use aiter rms norm op if aiter ops are enabled.
     "VLLM_ROCM_USE_AITER_RMSNORM":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
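Taken together, VLLM_ROCM_USE_AITER acts as the master switch and the two new flags refine which AITER MoE paths are considered (the real helpers additionally require a ROCm platform). A small standalone sketch of how the flags combine, reading the environment the same way the lambdas above do (names local to this example):

import os


def _flag(name: str, default: str) -> bool:
    return os.getenv(name, default).lower() in ("true", "1")


use_aiter = _flag("VLLM_ROCM_USE_AITER", "False")  # master switch, off by default
use_aiter_moe = use_aiter and _flag("VLLM_ROCM_USE_AITER_MOE", "True")
use_block_scaled_moe = use_aiter_moe and _flag(
    "VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false")  # opt-in on top of the MoE path

print(f"AITER MoE: {use_aiter_moe}, block-scaled FP8 MoE: {use_block_scaled_moe}")
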
@@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

+from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
+                                   rocm_aiter_fused_experts,
+                                   rocm_aiter_topk_softmax)
+
 logger = init_logger(__name__)

@@ -1035,6 +1039,28 @@ def try_get_optimal_moe_config(
     return config


+def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
+                      token_expert_indices: torch.Tensor,
+                      gating_output: torch.Tensor,
+                      renormalize: bool) -> tuple[torch.Tensor, ...]:
+    ops.topk_softmax(
+        topk_weights,
+        topk_indices,
+        token_expert_indices,
+        gating_output,
+    )
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_indices
+
+
+def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_topk_softmax
+    return vllm_topk_softmax
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -1059,17 +1085,14 @@ def fused_topk(
                                      dtype=torch.int32,
                                      device=hidden_states.device)

-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
+    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.
+
+    topk_func = dispatch_topk_func()
+    topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
+                                       token_expert_indicies,
+                                       gating_output_float, renormalize)

     del token_expert_indicies  # Not used. Will be used in the future.

-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
     return topk_weights, topk_ids

@@ -1259,6 +1282,24 @@ direct_register_custom_op(
 )


+def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
+    torch.ops.vllm.inplace_fused_experts(**kwargs)
+    hidden_states = kwargs['hidden_states']
+    return hidden_states
+
+
+def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
+    return torch.ops.vllm.outplace_fused_experts(**kwargs)
+
+
+def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
+    if is_rocm_aiter_moe_enabled():
+        return rocm_aiter_fused_experts
+    if inplace:
+        return torch_vllm_inplace_fused_experts
+    return torch_vllm_outplace_fused_experts
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
@@ -1278,20 +1319,25 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[List[int]] = None) -> torch.Tensor:
-    if inplace:
-        torch.ops.vllm.inplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
-        return hidden_states
-    else:
-        return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, activation,
-            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
-            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-            block_shape)
+    return dispatch_fused_experts_func(inplace)(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape)


 def fused_experts_impl(hidden_states: torch.Tensor,
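The fused_experts refactor above funnels every call through a dispatcher that picks a backend once and then forwards keyword arguments to it. A stripped-down sketch of the same dispatch-by-flag pattern (standalone; the backend names below are placeholders, not vLLM or AITER APIs):

from typing import Callable

import torch


def _inplace_backend(*, hidden_states: torch.Tensor, **_) -> torch.Tensor:
    hidden_states.mul_(2)  # stand-in for a kernel that writes into its input
    return hidden_states


def _outplace_backend(*, hidden_states: torch.Tensor, **_) -> torch.Tensor:
    return hidden_states * 2  # stand-in for a kernel that allocates a new output


def dispatch(use_alt_backend: bool, inplace: bool) -> Callable[..., torch.Tensor]:
    # Mirrors dispatch_fused_experts_func: a platform-specific backend wins,
    # otherwise fall back to the inplace/out-of-place torch ops.
    if use_alt_backend:
        return _outplace_backend  # placeholder for rocm_aiter_fused_experts
    return _inplace_backend if inplace else _outplace_backend


x = torch.ones(4)
out = dispatch(use_alt_backend=False, inplace=True)(hidden_states=x)
assert out is x  # the inplace path returns the mutated input tensor
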
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter

-from vllm import envs
+import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -16,6 +16,8 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -118,6 +120,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                                                layer.w2_weight.data),
                                            requires_grad=False)

+        if is_rocm_aiter_moe_enabled():
+            # reshaping weights is required for aiter moe kernel.
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
+
+            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                 requires_grad=False)
+
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 import intel_extension_for_pytorch as ipex
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py (new file, 157 lines)

@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
+from vllm.platforms import current_platform
+
+
+def is_rocm_aiter_moe_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_MOE \
+        and envs.VLLM_ROCM_USE_AITER \
+
+
+def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
+    return is_rocm_aiter_moe_enabled() and \
+        envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
+
+
+def rocm_aiter_fused_experts(
+        *,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        use_fp8_w8a8: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+        **kwagrs  # Ignore additional keyword arguments
+) -> torch.Tensor:
+
+    import aiter as rocm_aiter
+    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
+
+    if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
+        assert w1_scale is not None
+        assert w2_scale is not None
+
+        local_E = E = w1.shape[0]
+        if expert_mask is not None:
+            E = expert_mask.numel()
+
+        topk = topk_ids.shape[1]
+        model_dim = w1.shape[-1]
+        dtype = hidden_states.dtype
+        # The default block sizes are 128 in AITER.
+        if block_shape is None:
+            block_shape = [128, 128]
+
+        scale_blk_k = block_shape[1]
+
+        (
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            out_asm,
+        ) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
+                                               topk_weights,
+                                               E,
+                                               model_dim,
+                                               dtype,
+                                               expert_mask=expert_mask)
+
+        a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
+        rocm_aiter.fmoe_fp8_blockscale_g1u1(
+            out_asm,
+            a1,
+            w1,
+            w2,
+            sorted_token_ids,
+            sorted_weight_buf,
+            sorted_expert_ids,
+            num_valid_ids,
+            topk,
+            w1_scale.view(local_E, -1),
+            w2_scale.view(local_E, -1),
+            a1_scale.t().contiguous(),
+            block_shape[0],
+            block_shape[1],
+            None,
+        )
+        return out_asm
+
+    elif use_fp8_w8a8:
+        return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
+                                           w1=w1,
+                                           w2=w2,
+                                           topk_weight=topk_weights,
+                                           topk_ids=topk_ids,
+                                           fc1_scale=w1_scale,
+                                           fc2_scale=w2_scale,
+                                           fc1_smooth_scale=None,
+                                           fc2_smooth_scale=None,
+                                           a16=False)
+
+    return rocm_aiter.ck_moe(hidden_states=hidden_states,
+                             w1=w1,
+                             w2=w2,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids)
+
+
+def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
+                            topk_indices: torch.Tensor,
+                            token_expert_indices: torch.Tensor,
+                            gating_output: torch.Tensor,
+                            renormalize: bool) -> tuple[torch.Tensor, ...]:
+    import aiter as rocm_aiter
+    rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
+                            gating_output, renormalize)
+
+    return topk_weights, topk_indices
+
+
+def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+    """
+    Applies shuffle_weight function from AITER to each
+    input tensor and returns them.
+
+    Args:
+    *tensors: Variable number of torch.Tensor objects.
+
+    Returns:
+    A tuple of shuffled tensors.
+    """
+    from aiter.ops.shuffle import shuffle_weight
+
+    return tuple(shuffle_weight(tensor) for tensor in tensors)
+
+
+def expand_weights(*tensors: torch.Tensor,
+                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
+    """
+    Expands the dimensions of input tensors.
+
+    Args:
+    *tensors: A variable number of torch.Tensor objects.
+    expansion_dims: A list of expansion dimensions
+    corresponding to each tensor.
+
+    Returns:
+    A tuple of tensors with expanded dimensions.
+    """
+
+    assert len(tensors) == len(expansion_dims), \
+        "Number of tensors must match the number of expansion dimensions."
+
+    return tuple(
+        tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
+        for tensor, dim in zip(tensors, expansion_dims))
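To make the helper shapes concrete: expand_weights turns a per-expert scalar scale of shape (E,) into an (E, dim, 1) tensor so it lines up with the per-expert weight layout, and shuffle_weights simply maps AITER's shuffle_weight over each tensor. A small sketch of the expansion (the dims are illustrative Mixtral-like values, not taken from this diff; shuffle_weights is omitted because it needs the AITER package installed):

import torch

# Per-expert scalar scales for 4 experts, as an FP8 MoE checkpoint might carry.
w13_scale = torch.rand(4)
w2_scale = torch.rand(4)

# Same unsqueeze/expand sequence as expand_weights above.
expanded = tuple(
    t.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
    for t, dim in zip((w13_scale, w2_scale), (14336, 4096)))

print(expanded[0].shape)  # torch.Size([4, 14336, 1])
print(expanded[1].shape)  # torch.Size([4, 4096, 1])
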
@@ -13,6 +13,9 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
+    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -554,6 +557,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
+            if is_rocm_aiter_block_scaled_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight.data, layer.w2_weight.data)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
             return

         # If checkpoint is fp16, quantize in place.
@@ -581,6 +593,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
+            if is_rocm_aiter_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                w13_scales, w2_scales = expand_weights(
+                    layer.w13_weight_scale.data,
+                    layer.w2_weight_scale.data,
+                    expansion_dims=[
+                        layer.w13_weight.shape[1], layer.w2_weight.shape[1]
+                    ])
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_scales.contiguous(), requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
+
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
             return

         # If checkpoint is fp8, we need to handle that the
@@ -648,6 +680,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                             dq_weight, max_w13_scales[expert_id])
                     start += shard_size

+            if is_rocm_aiter_moe_enabled():
+                # reshaping weights is required for aiter moe kernel.
+                expansion_dims = [
+                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
+                ]
+                max_w13_scales, w2_scales = expand_weights(
+                    max_w13_scales,
+                    layer.w2_weight_scale.data,
+                    expansion_dims=expansion_dims)
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
+
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
+
+                layer.w13_weight = torch.nn.Parameter(shuffled_w13,
+                                                      requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(shuffled_w2,
+                                                     requires_grad=False)
+
             layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                         requires_grad=False)
             return
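Finally, as a hedged usage note: the block-scaled FP8 path in Fp8MoEMethod is only taken when the master switch, the MoE flag, and the block-scaled flag are all enabled, and only on ROCm. One way to opt in from Python before building an engine (the model id below is a placeholder, not something this change prescribes):

import os

# Master switch plus the MoE-specific flags added in this commit.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"  # already the default
os.environ["VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE"] = "1"  # opt-in

from vllm import LLM

llm = LLM(model="path/to/an-fp8-moe-checkpoint")  # placeholder model id
print(llm.generate("Hello"))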