[Misc] Fused MoE Marlin support for GPTQ (#8217)

Authored by Dipika Sikka on 2024-09-09 23:02:52 -04:00, committed by GitHub
parent c7cb5c3335
commit 6cd5e5b07e
19 changed files with 912 additions and 204 deletions

View File

@@ -386,7 +386,18 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
- label: Weight Loading Multiple GPU Test - Large Models # optional
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
##### multi gpus test #####

View File

@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}

View File

@@ -2,6 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List
import pytest
import torch
from transformers import MixtralConfig
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types
def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])
def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
torch.manual_seed(7)
if topk > e:
return
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return
quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[i] = torch.eye(k, n, device="cuda", dtype=dtype)
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
if topk > e:
return
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return
quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []
for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)
w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2

View File

@@ -0,0 +1,3 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main

View File

@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main

View File

@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",

View File

@@ -0,0 +1,219 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (torch.Tensor): The act_order indices.
- perm (torch.Tensor): The act_order input permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16
M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,
w.shape,
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation.
- perm2 (torch.Tensor): The second act_order input permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of top-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
g_idx1,
perm1,
workspace,
M,
2 * N,
K,
True,
E,
topk,
block_size_m,
True,
False,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
g_idx2,
perm2,
workspace,
M,
K,
N,
True,
E,
topk,
block_size_m,
False,
True,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
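As a sanity check on the workspace sizing used by the kernels above, the arithmetic can be reproduced standalone. A short sketch with illustrative, roughly Mixtral-sized shapes (the numbers are assumptions for the example, not taken from this diff):

# Hedged sketch: max_workspace_size as computed in fused_marlin_moe above,
# for M tokens, K hidden size and N intermediate size per partition.
M, N, K = 512, 14336, 4096
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
# (512 + 255) // 256 == 2 token tiles; max(28672, 4096) // 64 == 448
assert max_workspace_size == 2 * 448 * 16 == 14336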

View File

@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
return None
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
is_marlin: bool,
) -> Dict[str, int]:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
'BLOCK_SIZE_M': 16,
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
is_marlin: bool = False,
):
if override_config:
config = override_config
else:
@@ -391,6 +399,7 @@ def fused_topk(
topk,
dtype=torch.int32,
device=hidden_states.device)
ops.topk_softmax(
topk_weights,
topk_ids,
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
custom_routing_function: Optional[Callable] = None,
renormalize: bool = True,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
#TODO fp8 is not implemented yet
assert not use_fp8
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
if custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):

View File

@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
# compressed-tensors represents weights on disk which are flipped
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsMoEMethod") else loaded_weight
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
loaded_weight = loaded_weight.t().contiguous()
shard_dim = ~shard_dim
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
# Case weight scales and zero_points
if ("scale" in weight_name or "zero" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Case weight_shape
if "weight_shape" in weight_name:
# only required by compressed-tensors
self._load_single_value(param=param,
self._load_single_value(param=param, # only required by compressed-tensors
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case input scale
if "input_scale" in weight_name:
# Note: input_scale loading is only supported for fp8
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)

View File

@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits == 4):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for 4 bits")
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
@@ -269,10 +266,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
@@ -280,8 +288,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
renormalize=renormalize,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)

View File

@@ -22,7 +22,7 @@ from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())

View File

@@ -1,18 +1,22 @@
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
@@ -105,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference")
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight) replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format. # Permute scales from autogptq format to marlin format.
@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k=(layer.input_size if self.quant_config.desc_act else size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition), layer.input_size_per_partition),
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size) group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales) replace_tensor(layer, "scales", marlin_scales)
def apply( def apply(
@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full, is_k_full=layer.is_k_full,
bias=bias) bias=bias,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = intermediate_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": True
})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size // self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
# up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# up_proj qzeros
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
# down_proj qzeros
w2_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w2_scales", marlin_w2_scales)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=None)
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
).to(orig_dtype)
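With GPTQMarlinMoEMethod registered for FusedMoE layers, a 4-bit GPTQ Mixtral checkpoint exercises this path end to end. A minimal usage sketch (the model name comes from the weight-loading test list added above; the LLM arguments are assumptions about the vLLM API of this era, not part of this diff):

# Hedged sketch: load a GPTQ Mixtral checkpoint so its MoE layers are handled
# by GPTQMarlinMoEMethod; needs a GPU with enough memory for the 4-bit weights.
from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          quantization="gptq_marlin",   # routes FusedMoE layers to GPTQMarlinMoEMethod
          dtype="float16")              # the fused Marlin MoE kernel expects fp16 activations
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)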

View File

@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
):
num_experts = s.shape[0]
output = torch.empty(
(num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the

View File

@@ -1,6 +1,6 @@
"""Utility functions used for tests and benchmarks"""
from typing import List, Optional
import numpy as np
import torch
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return perm
def marlin_quantize(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, size_n = w.shape
num_bits = quant_type.size_bits
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order, test_perm)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing

View File

@@ -1,5 +1,5 @@
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional
import numpy
import torch
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return 32 // num_bits
def permute_rows(q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None):
assert q_w.shape == w_ref.shape
orig_device = q_w.device
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx[i] = i // group_size
# Simulate act_order by doing a random permutation on K
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
test_perm)
return w_ref, w_q, w_s, g_idx, rand_perm

View File

@@ -24,10 +24,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
return ModelRegistry.resolve_model_cls(architectures)

View File

@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):