vllm/tests/kernels/test_awq_marlin.py

165 lines
4.9 KiB
Python

"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
w_ref1_l = []
qweights1_l = []
scales1_l = []
zp1_l = []
for i in range(w1.shape[0]):
w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size)
w_ref1_l.append(w_ref1)
qweights1_l.append(qweight1)
scales1_l.append(scales1)
zp1_l.append(zp1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweights1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
zp1 = stack_and_dev(zp1_l)
w_ref2_l = []
qweights2_l = []
scales2_l = []
zp2_l = []
for i in range(w2.shape[0]):
w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size)
w_ref2_l.append(w_ref2)
qweights2_l.append(qweight2)
scales2_l.append(scales2)
zp2_l.append(zp2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweights2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
zp2 = stack_and_dev(zp2_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
w1_zeros=zp1,
w2_zeros=zp2,
num_bits=num_bits,
)
torch_output = torch_moe(
a,
w_ref1.transpose(1, 2),
w_ref2.transpose(1, 2),
score,
topk,
)
assert compute_max_diff(marlin_output, torch_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
):
torch.manual_seed(7)
num_bits = 4
quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
scales_l = []
zp_l = []
for i in range(w.shape[0]):
w_ref, qweight, scales, zp = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
zp_l.append(zp)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l).contiguous()
zp = stack_and_dev(zp_l).contiguous()
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2