# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip(
        "INT8 Triton requires CUDA compute capability 7.0 or higher",
        allow_module_level=True)


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    """Matrix multiplication function that supports per-token input
    quantization and per-column weight quantization."""
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(
    ), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
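    # Note: with B transposed, N is the shared reduction dim and K is the
    # output dim, i.e. the reverse of the usual (M, N, K) GEMM naming.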
    origin_C_shape = A.shape[:-1] + (K, )
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
    C = torch.matmul(A, B)  # [M, K]
    C = As * C * Bs.view(1, -1)  # Broadcast per-column scale
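    # Dequantization identity for symmetric quantization: with A ~ As * A_q
    # and B ~ Bs * B_q, C[m, k] ~ As[m] * Bs[k] * sum_j A_q[m, j] * B_q[k, j].
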
    return C.reshape(origin_C_shape).to(output_dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
    """This function performs fused moe with per-column int8 quantization
    using native torch."""

    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]

    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    # Calculate routing
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
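    # Flat row ordering matches the repeat above: token t's k-th routed copy
    # sits at row t * topk + k of a_q, a_s, topk_weight and topk_ids.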
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(a_q[mask],
                                                     w1[i],
                                                     a_s[mask],
                                                     w1_s[i],
                                                     output_dtype=a.dtype)
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
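            # SiluAndMul splits the 2N-wide intermediate into gate/up halves
            # and returns an N-wide tensor: silu(gate) * up.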
            # Quantize the activation output per token
            act_out_q, act_out_s = per_token_quant_int8(act_out)

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(act_out_q,
                                                     w2[i],
                                                     act_out_s,
                                                     w2_s[i],
                                                     output_dtype=a.dtype)
    # Apply routing weights and sum
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)

@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    torch.set_default_device("cuda")


DTYPES = [torch.half, torch.bfloat16]
M = [1, 33]
N = [128, 1024]
K = [256, 4096]
E = [8]
TOP_KS = [2, 6]
SEEDS = [0]


@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
                         itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_int8_fused_moe(M, N, K, E, topk, dtype, seed):
    torch.manual_seed(seed)
    # Initialize int8 quantization parameters
    factor_for_scale = 1e-2
    int8_max = 127
    int8_min = -128

    # Input tensor of shape [M, K]
    a = torch.randn((M, K), dtype=dtype) / 10

    # Generate int8 weights
    w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
    w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
    w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

    # Generate a scale for each output column (per-column quantization)
    w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
    w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
    score = torch.randn((M, E), dtype=dtype)

    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
    out = fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_int8_w8a8=True,  # Using int8-w8a8
        per_channel_quant=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=None,  # Not using block quantization
    )

    # Check results: mean absolute error relative to the reference output
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.05
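
# A minimal round-trip sanity sketch for per_token_quant_int8, assuming it
# returns the quantized int8 tensor plus one float scale per row, as its use
# in torch_w8a8_per_column_moe above implies.
@torch.inference_mode()
def test_per_token_quant_int8_roundtrip():
    torch.manual_seed(0)
    x = torch.randn(16, 64, dtype=torch.float16)
    x_q, x_s = per_token_quant_int8(x)
    assert x_q.dtype == torch.int8
    # Dequantize: broadcasting the per-row scale should approximately
    # recover the input, to within one quantization step per element.
    x_dq = x_q.to(torch.float32) * x_s.view(-1, 1)
    step = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / 127
    assert torch.all((x_dq - x.to(torch.float32)).abs() <= step)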