# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.

Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                 torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize)
from vllm.scalar_type import scalar_types

NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
GROUP_SIZES = [-1, 32, 128]


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.skipif(not (ops.supports_moe_ops
                         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
                    reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # AWQ-quantize each expert's w1 weights and stack them for the
    # Marlin MoE kernel. w_ref1 keeps the dequantized reference weights.
    w_ref1_l = []
    qweights1_l = []
    scales1_l = []
    zp1_l = []

    for i in range(w1.shape[0]):
        w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size)
        w_ref1_l.append(w_ref1)
        qweights1_l.append(qweight1)
        scales1_l.append(scales1)
        zp1_l.append(zp1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweights1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    zp1 = stack_and_dev(zp1_l)

    # Same procedure for the w2 (down-projection) weights.
    w_ref2_l = []
    qweights2_l = []
    scales2_l = []
    zp2_l = []

    for i in range(w2.shape[0]):
        w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size)
        w_ref2_l.append(w_ref2)
        qweights2_l.append(qweight2)
        scales2_l.append(scales2)
        zp2_l.append(zp2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweights2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    zp2 = stack_and_dev(zp2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    # Run the fused Marlin MoE kernel and compare against the unquantized
    # torch reference implementation.
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        w1_zeros=zp1,
        w2_zeros=zp2,
        num_bits=num_bits,
    )

    torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
                             score, topk, None)

    assert compute_max_diff(marlin_output, torch_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # AWQ-quantize each expert's weights for the single Marlin MoE matmul.
    w_ref_l = []
    qweights_l = []
    scales_l = []
    zp_l = []

    for i in range(w.shape[0]):
        w_ref, qweight, scales, zp = awq_marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        zp_l.append(zp)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l).contiguous()
    zp = stack_and_dev(zp_l).contiguous()

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = torch.ops.vllm.single_marlin_moe(a,
                                                     qweight,
                                                     scales,
                                                     score,
                                                     topk,
                                                     renormalize=False,
                                                     w_zeros=zp,
                                                     num_bits=num_bits)

    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2