[Misc] Fused MoE Marlin support for GPTQ (#8217)
parent c7cb5c3335
commit 6cd5e5b07e
@@ -386,7 +386,18 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

- label: Weight Loading Multiple GPU Test - Large Models # optional
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt

##### multi gpus test #####
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, max_par, replicate_input, apply_weights);
return c;
}
}
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
bool replicate_input, bool apply_weights);
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
@@ -2,6 +2,8 @@

Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
tests/weight_loading/models-large.txt (new file)
@@ -0,0 +1,3 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (new file)
@@ -0,0 +1,219 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)


def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int,
renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
"""
This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism.
Its purpose is testing and debugging the fused MoE kernel.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
- w (torch.Tensor): The set of expert weights.
- scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx (torch.Tensor): The act_order indices.
- perm (torch.Tensor): The act_order input permutation.
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16

M, K = hidden_states.shape
E = w.shape[0]
N = w.shape[2] // 2

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,
w.shape,
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
False)

return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
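
A minimal usage sketch for this helper, mirroring test_marlin_moe_mmm in tests/kernels/test_moe.py; it assumes a CUDA device, a vLLM build with the _moe_C extension, and Marlin-compatible sizes (the values below are one of the test's parametrizations):

import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import single_marlin_moe
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.scalar_type import scalar_types

m, n, k, e, topk = 33, 256, 512, 8, 2
a = torch.randn((m, k), device="cuda", dtype=torch.float16) / 10
w = torch.randn((e, n, k), device="cuda", dtype=torch.float16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.float16)

qweights, scales, g_idxs, perms = [], [], [], []
for i in range(e):
    # marlin_quantize returns (w_ref, qweight, scales, g_idx, sort_indices, perm)
    _, q_i, s_i, g_i, p_i, _ = marlin_quantize(w[i].transpose(1, 0),
                                               scalar_types.uint4b8, 128, False)
    qweights.append(q_i)
    scales.append(s_i)
    g_idxs.append(g_i)
    perms.append(p_i)

out = single_marlin_moe(a,
                        torch.stack(qweights).contiguous(),
                        torch.stack(scales),
                        score,
                        torch.stack(g_idxs),
                        torch.stack(perms),
                        topk,
                        renormalize=False)  # expected shape: (m, n), float16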


def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation.
- perm2 (torch.Tensor): The second act_order input permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of top-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[
0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16

M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]

get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)

block_size_m = config["BLOCK_SIZE_M"]

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states,
w1,
sorted_token_ids,
topk_weights,
topk_ids,
w1_scale,
g_idx1,
perm1,
workspace,
M,
2 * N,
K,
True,
E,
topk,
block_size_m,
True,
False,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,
w2,
sorted_token_ids,
topk_weights,
topk_ids,
w2_scale,
g_idx2,
perm2,
workspace,
M,
K,
N,
True,
E,
topk,
block_size_m,
False,
True,
)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
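
For reference, a dense (unquantized) sketch of what the two marlin_gemm_moe calls plus silu_and_mul compute for a single token; shapes follow the asserts above (per-expert w1 is (2*N, K), per-expert w2 is (K, N)), and the helper name is hypothetical:

import torch
import torch.nn.functional as F


def moe_token_reference(x, w1, w2, topk_weights, topk_ids):
    # x: (K,); w1: (E, 2*N, K); w2: (E, K, N). One (weight, expert) pair per
    # selected expert, matching apply_weights=True on the second GEMM.
    out = torch.zeros_like(x)
    for weight, eid in zip(topk_weights.tolist(), topk_ids.tolist()):
        h = w1[eid] @ x                     # first GEMM -> (2*N,)
        gate, up = h.chunk(2)               # silu_and_mul: silu(gate) * up
        out += weight * (w2[eid] @ (F.silu(gate) * up))  # second GEMM -> (K,)
    return out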
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
return None


def get_default_config(M: int, E: int, N: int, K: int, topk: int,
dtype: Optional[str],
is_marlin: bool) -> Dict[str, int]:
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
is_marlin: bool,
) -> Dict[str, int]:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
# A heuristic: fused marlin works faster with this config for small M
if M <= E or (is_marlin and M <= 32):
config = {
'BLOCK_SIZE_M': 16,
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
return config
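
A tiny illustration of the small-M heuristic above (the block sizes of the small-M branch are truncated in this hunk, so only the branch condition is shown; the function name is hypothetical):

def uses_small_m_config(M: int, E: int, is_marlin: bool) -> bool:
    # Mirrors the condition guarding the small-M config in get_default_config.
    return M <= E or (is_marlin and M <= 32)


assert uses_small_m_config(M=16, E=8, is_marlin=True)        # small batch, Marlin
assert not uses_small_m_config(M=4096, E=8, is_marlin=False)  # large batch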


def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str,
Any]] = None,
is_marlin: bool = False):
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
is_marlin: bool = False,
):
if override_config:
config = override_config
else:
@@ -391,6 +399,7 @@ def fused_topk(
topk,
dtype=torch.int32,
device=hidden_states.device)

ops.topk_softmax(
topk_weights,
topk_ids,
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids


def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
custom_routing_function: Optional[Callable] = None,
renormalize: bool = True,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]

#TODO fp8 is not implemented yet
assert not use_fp8

M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16

if custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)

get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

block_size_m = config['BLOCK_SIZE_M']

sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)

intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)

intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)


def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight

def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):

if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:

# compressed-tensors represents weights on disk which are flipped
loaded_weight = loaded_weight.t().contiguous() if (
self.quant_method.__class__.__name__
== "CompressedTensorsMoEMethod") else loaded_weight

if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: whether or not the parameter is transposed on disk
# If transposed, the loaded weight will be transposed and the dim
# to shard the loaded weight will be flipped.
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
loaded_weight = loaded_weight.t().contiguous()
shard_dim = ~shard_dim
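# (For a 2-D weight, ~0 == -1 and ~1 == -2, so the bitwise-not maps the
# shard dim onto the equivalent axis counted from the end, flipping which
# axis is sharded once the weight has been transposed.)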

# Case weight_scales
if "weight_scale" in weight_name:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")

self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return

# Case g_idx
if "g_idx" in weight_name:
self._load_g_idx(shard_dim=0,
shard_id=shard_id,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return

# Case weight scales and zero_points
if ("scale" in weight_name or "zero" in weight_name):
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return

# Case weight_shape
if "weight_shape" in weight_name:
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return

# Case input scale
if "input_scale" in weight_name:
# Note: input_scale loading is only supported for fp8
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")

# only required by compressed-tensors
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
@@ -498,4 +525,4 @@ class FusedMoE(torch.nn.Module):
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
param_data[expert_id] = loaded_weight
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS):
and self.num_bits == 4):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
"is supported for 4 bits")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)

return fused_marlin_moe(x,
layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
top_k,
custom_routing_function,
renormalize=renormalize,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)

return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
@@ -22,7 +22,7 @@ from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
8: scalar_types.uint8b128
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
@@ -1,18 +1,22 @@
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}

def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool) -> None:
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
@@ -105,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference")
return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)
group_size=group_size,
)

# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight)

# Permute scales from autogptq format to marlin format.
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales)

def apply(
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias)
bias=bias,
)


class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""

def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
if self.quant_config.group_size != -1:
scales_size13 = hidden_size // self.quant_config.group_size
scales_size2 = intermediate_size // self.quant_config.group_size
strategy = FusedMoeWeightScaleSupported.GROUP.value
else:
scales_size13 = 1
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value

extra_weight_attrs.update({
"quant_method": strategy,
"is_transposed": True
})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size // self.quant_config.pack_factor,
2 * intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size // self.quant_config.pack_factor,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
# up_proj scales
w13_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
# down_proj scales
w2_scales = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
# up_proj scales
w13_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size13,
2 * intermediate_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
# down_proj scales
w2_qzeros = torch.nn.Parameter(
torch.empty(num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w2_scales", marlin_w2_scales)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)

# The input must currently be float16
orig_dtype = x.dtype
x = x.half()

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=None)

return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights,
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
).to(orig_dtype)
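
An end-to-end sketch of how this method gets selected in practice (the model name is taken from tests/weight_loading/models-large.txt; assumes two GPUs with enough memory, as in the CI job above):

# When the checkpoint is 4-bit GPTQ, GPTQMarlinConfig.get_quant_method returns
# GPTQMarlinMoEMethod for Mixtral's FusedMoE layers.
from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ",
          quantization="gptq_marlin",
          tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)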
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s


def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
):
num_experts = s.shape[0]
output = torch.empty(
(num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype,
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
return output
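
A small shape sketch for the new helper (the sizes below are assumptions chosen to be Marlin-friendly): per-expert scales of shape (size_k // group_size, size_n) are permuted expert by expert, and the stacked (num_experts, ...) layout is preserved.

import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_moe_permute_scales)

num_experts, size_k, size_n, group_size = 8, 2048, 4096, 128
s = torch.rand((num_experts, size_k // group_size, size_n), dtype=torch.half)
out = marlin_moe_permute_scales(s, size_k=size_k, size_n=size_n,
                                group_size=group_size)
assert out.shape == s.shape  # layout preserved; values permuted per expert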


def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
@@ -1,6 +1,6 @@
"""Utility functions used for tests and benchmarks"""

from typing import List
from typing import List, Optional

import numpy as np
import torch
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return perm


def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
act_order: bool):
def marlin_quantize(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, size_n = w.shape
num_bits = quant_type.size_bits

@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,

# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
w, quant_type, group_size, act_order)
w, quant_type, group_size, act_order, test_perm)

# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
@@ -1,5 +1,5 @@
"""This file is used for /tests and /benchmarks"""
from typing import List
from typing import List, Optional

import numpy
import torch
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return 32 // num_bits


def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def permute_rows(q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None):
assert q_w.shape == w_ref.shape

orig_device = q_w.device
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx[i] = i // group_size

# Simulate act_order by doing a random permutation on K
rand_perm = torch.randperm(k_size)
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)


def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
group_size: int, act_order: bool):
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, _ = w.shape

assert w.is_floating_point(), "w must be float"
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)

w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
test_perm)

return w_ref, w_q, w_s, g_idx, rand_perm
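
A sketch of using the new test_perm hook to make the simulated act_order permutation deterministic (sizes and dtype here are assumptions):

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_quantize_weights)
from vllm.scalar_type import scalar_types

size_k, size_n, group_size = 512, 256, 128
w = torch.randn(size_k, size_n, dtype=torch.float32)
perm = torch.randperm(size_k)

w_ref, w_q, w_s, g_idx, used_perm = gptq_quantize_weights(
    w, scalar_types.uint4b8, group_size, act_order=True, test_perm=perm)
# used_perm should be the permutation supplied via test_perm rather than a
# freshly drawn torch.randperm, which makes quantization reproducible in tests.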


@@ -24,10 +24,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")

if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

return ModelRegistry.resolve_model_cls(architectures)
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):