[Misc] Fused MoE Marlin support for GPTQ (#8217)

parent c7cb5c3335
commit 6cd5e5b07e
@@ -386,7 +386,18 @@ steps:
   - vllm/
   - tests/weight_loading
   commands:
-    - bash weight_loading/run_model_weight_loading_test.sh
+    - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
+
+- label: Weight Loading Multiple GPU Test - Large Models # optional
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  gpu: a100
+  optional: true
+  source_file_dependencies:
+  - vllm/
+  - tests/weight_loading
+  commands:
+    - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt

 ##### multi gpus test #####
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
       "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
       "bool replicate_input, bool apply_weights) -> Tensor");
-
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }
@@ -2,6 +2,8 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+from typing import List
+
 import pytest
 import torch
 from transformers import MixtralConfig
@@ -9,7 +11,13 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    fused_marlin_moe, single_marlin_moe)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
+    marlin_quantize)
 from vllm.model_executor.models.mixtral import MixtralMoE
+from vllm.scalar_type import scalar_types
 
 
 def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
+def torch_moe_single(a, w, score, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    _, topk_ids = torch.topk(score, topk)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = a[mask] @ w[i].transpose(0, 1)
+    return (out.view(B, -1, w.shape[1])).sum(dim=1)
+
+
 @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
     topk: int,
     dtype: torch.dtype,
 ):
-    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
-    score = torch.randn((m, e), device='cuda', dtype=dtype)
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
     torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
                                vllm_states,
                                rtol=mixtral_moe_tol[dtype],
                                atol=mixtral_moe_tol[dtype])
+
+
+def stack_and_dev(tensors: List[torch.Tensor]):
+    dev = tensors[0].device
+    return torch.stack(tensors, dim=0).to(dev)
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref))
+
+
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_fused_marlin_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    torch.manual_seed(7)
+
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size in (k, n):
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    for i in range(w2.shape[0]):
+        w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
+
+    w_ref1_l = []
+    qweight1_l = []
+    scales1_l = []
+    g_idx1_l = []
+    sort_indices1_l = []
+
+    for i in range(w1.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
+            w1[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref1_l.append(w_ref1)
+        qweight1_l.append(qweight1)
+        scales1_l.append(scales1)
+        g_idx1_l.append(g_idx1)
+        sort_indices1_l.append(sort_indices1)
+
+    w_ref1 = stack_and_dev(w_ref1_l)
+    qweight1 = stack_and_dev(qweight1_l).contiguous()
+    scales1 = stack_and_dev(scales1_l)
+    g_idx1 = stack_and_dev(g_idx1_l)
+    sort_indices1 = stack_and_dev(sort_indices1_l)
+
+    w_ref2_l = []
+    qweight2_l = []
+    scales2_l = []
+    g_idx2_l = []
+    sort_indices2_l = []
+
+    for i in range(w2.shape[0]):
+        test_perm = torch.randperm(n)
+        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
+            w2[i].transpose(1, 0), quant_type, group_size, act_order,
+            test_perm)
+        w_ref2_l.append(w_ref2)
+        qweight2_l.append(qweight2)
+        scales2_l.append(scales2)
+        g_idx2_l.append(g_idx2)
+        sort_indices2_l.append(sort_indices2)
+
+    w_ref2 = stack_and_dev(w_ref2_l)
+    qweight2 = stack_and_dev(qweight2_l).contiguous()
+    scales2 = stack_and_dev(scales2_l)
+    g_idx2 = stack_and_dev(g_idx2_l)
+    sort_indices2 = stack_and_dev(sort_indices2_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    topk_weights, topk_ids = fused_topk(a, score, topk, False)
+
+    triton_output = fused_moe(
+        a,
+        w_ref1.transpose(1, 2).contiguous(),
+        w_ref2.transpose(1, 2).contiguous(),
+        score,
+        topk,
+        renormalize=False,
+    )
+    marlin_output = fused_marlin_moe(
+        a,
+        qweight1,
+        qweight2,
+        score,
+        g_idx1,
+        g_idx2,
+        sort_indices1,
+        sort_indices2,
+        topk_weights,
+        topk_ids,
+        w1_scale=scales1,
+        w2_scale=scales2,
+    )
+
+    assert compute_max_diff(marlin_output, triton_output) < 4e-2
+
+
+@pytest.mark.skip("This test is here for the sake of debugging, "
+                  "don't run it in automated tests.")
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("act_order", [True, False])
+def test_marlin_moe_mmm(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    group_size: int,
+    act_order: bool,
+):
+    if topk > e:
+        return
+
+    # Filter act_order
+    if act_order:
+        if group_size == -1:
+            return
+        if group_size == k:
+            return
+
+    quant_type = scalar_types.uint4b8
+    dtype = torch.float16
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
+
+    w_ref_l = []
+    qweights_l = []
+    scales_l = []
+    g_idx_l = []
+    sort_indices_l = []
+
+    for i in range(w.shape[0]):
+        test_perm = torch.randperm(k)
+        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
+            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
+        w_ref_l.append(w_ref)
+        qweights_l.append(qweight)
+        scales_l.append(scales)
+        g_idx_l.append(g_idx)
+        sort_indices_l.append(sort_indices)
+
+    w_ref = stack_and_dev(w_ref_l)
+    qweight = stack_and_dev(qweights_l).contiguous()
+    scales = stack_and_dev(scales_l)
+    g_idx = stack_and_dev(g_idx_l)
+    sort_indices = stack_and_dev(sort_indices_l)
+
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+    marlin_output = single_marlin_moe(a,
+                                      qweight,
+                                      scales,
+                                      score,
+                                      g_idx,
+                                      sort_indices,
+                                      topk,
+                                      renormalize=False)
+    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
+
+    assert compute_max_diff(marlin_output, torch_output) < 1e-2
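Editorial note (not part of the diff): compute_max_diff in the new tests is a mean relative error, so the assertions above bound the Marlin output to within roughly 4% (two-kernel path) and 1% (single-matmul path) of the reference. A minimal standalone check of the metric, with hypothetical tensors:

    import torch

    def compute_max_diff(output, output_ref):
        # Same formula as in the test: mean |diff| normalized by mean |ref|.
        return torch.mean(torch.abs(output - output_ref)) / torch.mean(
            torch.abs(output_ref))

    print(compute_max_diff(torch.tensor([1.01, 2.0]),
                           torch.tensor([1.00, 2.0])))  # ~0.0033, well under 1e-2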
tests/weight_loading/models-large.txt (new file, +3 lines)
@@ -0,0 +1,3 @@
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
+compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
 compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
@@ -2,16 +2,22 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.triton_utils import HAS_TRITON
 
-__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+    "FusedMoeWeightScaleSupported",
+]
 
 if HAS_TRITON:
+    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+        fused_marlin_moe, single_marlin_moe)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_marlin_moe, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
 
     __all__ += [
         "fused_marlin_moe",
+        "single_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (new file, +219 lines)
@@ -0,0 +1,219 @@
+"""Fused MoE utilities for GPTQ."""
+import functools
+from typing import Any, Dict, Optional
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+
+
+def single_marlin_moe(
+        hidden_states: torch.Tensor,
+        w: torch.Tensor,
+        scales: torch.Tensor,
+        gating_output: torch.Tensor,
+        g_idx: torch.Tensor,
+        perm: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor:
+    """
+    This function computes the multiplication of hidden_states with expert
+    weights used in Marlin MoE, using weights w and top-k gating mechanism.
+    Its purpose is testing and debugging the fused MoE kernel.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
+    - w (torch.Tensor): The set of expert weights.
+    - scales (torch.Tensor): The quantization scales.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - g_idx (torch.Tensor): The act_order indices.
+    - perm (torch.Tensor): The act_order input permutation.
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
+    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w.is_contiguous(), "Expert weights must be contiguous"
+    assert hidden_states.dtype == torch.float16
+
+    M, K = hidden_states.shape
+    E = w.shape[0]
+    N = w.shape[2] // 2
+
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                        renormalize)
+
+    # This might not be an optimal config for a single MMM
+    get_config_func = functools.partial(try_get_optimal_moe_config,
+                                        w.shape,
+                                        w.shape,
+                                        topk_ids.shape[1],
+                                        None,
+                                        override_config=override_config,
+                                        is_marlin=True)
+    config = get_config_func(M)
+
+    block_size_m = config['BLOCK_SIZE_M']
+
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
+
+    max_workspace_size = (N // 64) * 16
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device="cuda",
+                            requires_grad=False)
+
+    intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
+        hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
+        g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True,
+        False)
+
+    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
+
+
+def fused_marlin_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    g_idx1: torch.Tensor,
+    g_idx2: torch.Tensor,
+    perm1: torch.Tensor,
+    perm2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    override_config: Optional[Dict[str, Any]] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a Mixture of Experts (MoE) layer using two sets of
+    weights, w1 and w2, and top-k gating mechanism.
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - g_idx1 (torch.Tensor): The first set of act_order indices.
+    - g_idx2 (torch.Tensor): The second set of act_order indices.
+    - perm1 (torch.Tensor): The first act_order input permutation.
+    - perm2 (torch.Tensor): The second act_order input permutation.
+    - topk_weights (torch.Tensor): Top-k weights.
+    - topk_ids (torch.Tensor): Indices of topk-k elements.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[
+        0], "Number of tokens mismatch"
+    assert hidden_states.shape[
+        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
+    assert hidden_states.shape[
+        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype == torch.float16
+
+    M, K = hidden_states.shape
+    E = w1.shape[0]
+    N = w2.shape[1] * 16
+    topk = topk_ids.shape[1]
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        None,
+        override_config=override_config,
+        is_marlin=True,
+    )
+    config = get_config_func(M)
+
+    block_size_m = config["BLOCK_SIZE_M"]
+
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
+
+    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
+    workspace = torch.zeros(max_workspace_size,
+                            dtype=torch.int,
+                            device="cuda",
+                            requires_grad=False)
+
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
+        hidden_states,
+        w1,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w1_scale,
+        g_idx1,
+        perm1,
+        workspace,
+        M,
+        2 * N,
+        K,
+        True,
+        E,
+        topk,
+        block_size_m,
+        True,
+        False,
+    )
+
+    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+
+    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
+        intermediate_cache2,
+        w2,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w2_scale,
+        g_idx2,
+        perm2,
+        workspace,
+        M,
+        K,
+        N,
+        True,
+        E,
+        topk,
+        block_size_m,
+        False,
+        True,
+    )
+
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)
@@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
-def get_default_config(M: int, E: int, N: int, K: int, topk: int,
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
     dtype: Optional[str],
-    is_marlin: bool) -> Dict[str, int]:
+    is_marlin: bool,
+) -> Dict[str, int]:
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
         'BLOCK_SIZE_K': 32,
         'GROUP_SIZE_M': 8
     }
+    # A heuristic: fused marlin works faster with this config for small M
     if M <= E or (is_marlin and M <= 32):
         config = {
             'BLOCK_SIZE_M': 16,
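Editorial sketch (not part of the commit): the new heuristic above means the Marlin path always uses the small-M config whenever there are at most 32 tokens. A hypothetical probe of that branch, assuming the vLLM source tree is importable and get_default_config keeps the signature shown in this hunk (the keys beyond BLOCK_SIZE_M are outside this excerpt):

    from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config

    # M=16 tokens, E=8 experts, N=14336, K=4096, topk=2, no quantized dtype, Marlin on
    cfg = get_default_config(16, 8, 14336, 4096, 2, None, True)
    print(cfg["BLOCK_SIZE_M"])  # expected to take the small-M branch -> 16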
@@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
     return config
 
 
-def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
     w2_shape: Tuple[int, ...],
     top_k: int,
     dtype: Optional[str],
     M: int,
-    override_config: Optional[Dict[str,
-                                   Any]] = None,
-    is_marlin: bool = False):
+    override_config: Optional[Dict[str, Any]] = None,
+    is_marlin: bool = False,
+):
     if override_config:
         config = override_config
     else:
@@ -391,6 +399,7 @@ def fused_topk(
                             topk,
                             dtype=torch.int32,
                             device=hidden_states.device)
+
     ops.topk_softmax(
         topk_weights,
         topk_ids,
@@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor,
     return topk_weights, topk_ids
 
 
-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     custom_routing_function: Optional[Callable] = None,
-                     renormalize: bool = True,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[
-        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-
-    #TODO fp8 is not implemented yet
-    assert not use_fp8
-
-    M, K = hidden_states.shape
-    E = w1.shape[0]
-    N = w2.shape[1] * 16
-
-    if custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize)
-
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w1.shape,
-                                        w2.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-
-    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
-        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
-        block_size_m, True, False)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
-
-    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
-        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
-        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
-        block_size_m, False, True)
-
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
-
-
 def get_config_dtype_str(dtype: torch.dtype,
                          use_int8_w8a16: Optional[bool] = False,
                          use_fp8_w8a8: Optional[bool] = False):
@@ -306,10 +306,28 @@ class FusedMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         param_data[expert_id] = loaded_weight
 
+    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
+                    shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
+
+        if shard_id == "w2":
+            self._load_w2(shard_id=shard_id,
+                          shard_dim=shard_dim,
+                          loaded_weight=loaded_weight,
+                          expert_data=expert_data,
+                          tp_rank=tp_rank)
+        else:
+            assert shard_id in ("w1", "w3")
+            expert_data.copy_(loaded_weight)
+
     def weight_loader(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, weight_name: str,
                       shard_id: str, expert_id: int) -> None:
+
+        # compressed-tensors represents weights on disk which are flipped
+        loaded_weight = loaded_weight.t().contiguous() if (
+            self.quant_method.__class__.__name__
+            == "CompressedTensorsMoEMethod") else loaded_weight
+
         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                              f"got {shard_id}.")
@@ -325,19 +343,41 @@ class FusedMoE(torch.nn.Module):
         expert_data = param.data[expert_id]
         tp_rank = get_tensor_model_parallel_rank()
 
-        # is_transposed: whether or not the parameter is transposed on disk
-        # If transposed, the loaded weight will be transposed and the dim
-        # to shard the loaded weight will be flipped.
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            loaded_weight = loaded_weight.t().contiguous()
             shard_dim = ~shard_dim
 
-        # Case weight_scales
-        if "weight_scale" in weight_name:
-            # load the weight scaling based on the quantization scheme
-            # supported weight scales can be found in
+        # Case input scale: input_scale loading is only supported for fp8
+        if "input_scale" in weight_name:
+            if param.data[expert_id] != 1 and (param.data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "input_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param.data[expert_id]} "
+                    f"vs. {loaded_weight}")
+
+            self._load_single_value(param=param,
+                                    loaded_weight=loaded_weight,
+                                    expert_id=expert_id)
+            return
+
+        # Case g_idx
+        if "g_idx" in weight_name:
+            self._load_g_idx(shard_dim=0,
+                             shard_id=shard_id,
+                             loaded_weight=loaded_weight,
+                             expert_data=expert_data,
+                             tp_rank=tp_rank)
+            return
+
+        # Case weight scales and zero_points
+        if ("scale" in weight_name or "zero" in weight_name):
+            # load the weight scales and zp based on the quantization scheme
+            # supported weight scales/zp can be found in
             # FusedMoeWeightScaleSupported
             # TODO @dsikka: once hardened, refactor to use vLLM Parameters
             # specific to each case
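Editorial note (not part of the diff): the `shard_dim = ~shard_dim` flip kept above relies on Python's bitwise NOT, which maps a non-negative dim to the corresponding negative index of a 2-D expert weight, i.e. the "other" axis once the on-disk tensor is transposed. A one-line illustration:

    for shard_dim in (0, 1):
        print(shard_dim, ~shard_dim)  # 0 -> -1, 1 -> -2 (last vs. second-to-last axis)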
@@ -366,22 +406,9 @@ class FusedMoE(torch.nn.Module):
                     f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
             return
 
+        # Case weight_shape
         if "weight_shape" in weight_name:
-            self._load_single_value(param=param,
-                                    loaded_weight=loaded_weight,
-                                    expert_id=expert_id)
-            return
-
-        # Case input scale
-        if "input_scale" in weight_name:
-            # Note: input_scale loading is only supported for fp8
-            if param.data[expert_id] != 1 and (param.data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param.data[expert_id]} "
-                    f"vs. {loaded_weight}")
-
+            # only required by compressed-tensors
             self._load_single_value(param=param,
                                     loaded_weight=loaded_weight,
                                     expert_id=expert_id)
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat)
 from vllm.model_executor.utils import set_weight_attrs
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 
         if not (self.quant_config.quant_format
                 == CompressionFormat.pack_quantized.value
-                and self.num_bits in WNA16_SUPPORTED_BITS):
+                and self.num_bits == 4):
             raise ValueError("For Fused MoE layers, only ",
                              f"{CompressionFormat.pack_quantized.value} ",
-                             "is supported for the following bits: ",
-                             f"{WNA16_SUPPORTED_BITS}")
+                             "is supported for 4 bits")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
@@ -269,10 +266,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
 
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
             fused_marlin_moe)
 
-        return fused_marlin_moe(x,
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
+
+        return fused_marlin_moe(
+            x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
             router_logits,
@@ -280,8 +288,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             layer.w2_g_idx,
             layer.w13_g_idx_sort_indices,
             layer.w2_g_idx_sort_indices,
-            top_k,
-            custom_routing_function,
-            renormalize=renormalize,
+            topk_weights,
+            topk_ids,
             w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale)
+            w2_scale=layer.w2_weight_scale,
+        )
@@ -22,7 +22,7 @@ from vllm.scalar_type import scalar_types
 __all__ = ["CompressedTensorsWNA16"]
 WNA16_SUPPORTED_TYPES_MAP = {
     4: scalar_types.uint4b8,
-    8: scalar_types.uint8b128,
+    8: scalar_types.uint8b128
 }
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
@@ -1,18 +1,22 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.nn import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
-    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
-    verify_marlin_supported, verify_marlin_supports_shape)
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
+    verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
-                 is_sym: bool, lm_head_quantized: bool) -> None:
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        is_sym: bool,
+        lm_head_quantized: bool,
+    ) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -105,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
                         " faster inference")
         return None
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
+        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
+                                             and self.lm_head_quantized):
             return GPTQMarlinLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             output_size_per_partition=output_size_per_partition,
             input_size_per_partition=input_size_per_partition,
             input_size=input_size,
-            group_size=group_size)
+            group_size=group_size,
+        )
 
         # Determine sharding
         if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             perm=layer.g_idx_sort_indices,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.quant_type.size_bits)
+            num_bits=self.quant_config.quant_type.size_bits,
+        )
         replace_tensor(layer, "qweight", marlin_qweight)
 
         # Permute scales from autogptq format to marlin format.
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             size_k=(layer.input_size if self.quant_config.desc_act else
                     layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
-            group_size=self.quant_config.group_size)
+            group_size=self.quant_config.group_size,
+        )
         replace_tensor(layer, "scales", marlin_scales)
 
     def apply(
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             is_k_full=layer.is_k_full,
-            bias=bias)
+            bias=bias,
+        )
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Currently assuming is_k_full is always True
+        # (input size per partition is the same as full input size)
+        # Supports only sym for now (no zp)
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            scales_size2 = intermediate_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+
+        extra_weight_attrs.update({
+            "quant_method": strategy,
+            "is_transposed": True
+        })
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size13,
+                        2 * intermediate_size,
+                        dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size2,
+                        hidden_size,
+                        dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size13,
+                        2 * intermediate_size // self.quant_config.pack_factor,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size2,
+                        hidden_size // self.quant_config.pack_factor,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices",
+                                 w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices",
+                                 w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w13_g_idx[e]).to(torch.int32)
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32)
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
+                    w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
+                    w2_g_idx_sort_indices[e]]
+            replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_tensor(layer, "w13_g_idx_sort_indices",
+                           w13_g_idx_sort_indices)
+            replace_tensor(layer, "w2_g_idx_sort_indices",
+                           w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_tensor(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_tensor(layer, "w2_scales", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            fused_marlin_moe)
+
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=None)
+
+        return fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            router_logits,
+            layer.w13_g_idx,
+            layer.w2_g_idx,
+            layer.w13_g_idx_sort_indices,
+            layer.w2_g_idx_sort_indices,
+            topk_weights,
+            topk_ids,
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+        ).to(orig_dtype)
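Editorial sketch (not part of the commit): with GPTQMarlinMoEMethod registered via get_quant_method above, a GPTQ-quantized MoE checkpoint should route its FusedMoE layers through the Marlin path. An illustrative end-to-end check, using the model name from the new weight-loading list (hardware requirements and sampling settings are assumptions, not stated in the diff):

    from vllm import LLM, SamplingParams

    llm = LLM(model="TheBloke/Mixtral-8x7B-v0.1-GPTQ", quantization="gptq_marlin")
    out = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)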
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
     return s
 
 
+def marlin_moe_permute_scales(
+    s: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    group_size: int,
+):
+    num_experts = s.shape[0]
+    output = torch.empty(
+        (num_experts, s.shape[1], s.shape[2]),
+        device=s.device,
+        dtype=s.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
+    return output
+
+
 def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                        num_bits: int) -> torch.Tensor:
     # Permute zero-points in a similar way to scales, but do not use the
|
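The new helper simply applies the existing 2-D scale permutation expert by expert over a 3-D tensor. Below is a self-contained sketch of that batching pattern, with a stand-in `permute_2d` in place of `marlin_permute_scales` (which needs Marlin's permutation tables) and an assumed scale layout of `(num_experts, size_k // group_size, size_n)` chosen purely for illustration.

import torch

def permute_2d(s: torch.Tensor, size_k: int, size_n: int, group_size: int):
    # Stand-in for marlin_permute_scales: any per-expert 2-D transform works
    # for illustrating the loop; here we just reverse the columns.
    return s.flip(dims=[-1])

def batched_permute(s: torch.Tensor, size_k: int, size_n: int, group_size: int):
    # Same structure as marlin_moe_permute_scales: allocate the output once,
    # then transform each expert's 2-D scale matrix independently.
    out = torch.empty_like(s)
    for e in range(s.shape[0]):
        out[e] = permute_2d(s[e], size_k, size_n, group_size)
    return out

scales = torch.randn(8, 512 // 128, 256)  # (num_experts, size_k // group_size, size_n)
print(batched_permute(scales, 512, 256, 128).shape)  # torch.Size([8, 4, 256])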
@@ -1,6 +1,6 @@
 """Utility functions used for tests and benchmarks"""
 
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
     return perm
 
 
-def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
-                    act_order: bool):
+def marlin_quantize(w: torch.Tensor,
+                    quant_type: ScalarType,
+                    group_size: int,
+                    act_order: bool,
+                    test_perm: Optional[torch.Tensor] = None):
     size_k, size_n = w.shape
     num_bits = quant_type.size_bits
 
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
 
     # Quantize (and apply act_order if provided)
     w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
-        w, quant_type, group_size, act_order)
+        w, quant_type, group_size, act_order, test_perm)
 
     # For act_order, sort the "weights" and "g_idx" so that group ids are
     # increasing
@@ -1,5 +1,5 @@
 """This file is used for /tests and /benchmarks"""
-from typing import List
+from typing import List, Optional
 
 import numpy
 import torch
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
     return 32 // num_bits
 
 
-def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
+def permute_rows(q_w: torch.Tensor,
+                 w_ref: torch.Tensor,
+                 group_size: int,
+                 test_perm: Optional[torch.Tensor] = None):
     assert q_w.shape == w_ref.shape
 
     orig_device = q_w.device
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
         g_idx[i] = i // group_size
 
     # Simulate act_order by doing a random permutation on K
-    rand_perm = torch.randperm(k_size)
+    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
 
     g_idx = g_idx[rand_perm].contiguous()
     q_w = q_w[rand_perm, :].contiguous()
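The point of the new `test_perm` argument is reproducibility: when the caller supplies a fixed permutation, the simulated act_order shuffle is no longer random, so two quantization runs over the same weights agree. A minimal self-contained sketch of that defaulting pattern (the helper name and tensor sizes are illustrative):

from typing import Optional

import torch

def shuffle_rows(q_w: torch.Tensor,
                 test_perm: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same pattern as permute_rows: use the caller's permutation when given,
    # otherwise fall back to a fresh random one.
    k_size = q_w.shape[0]
    perm = test_perm if test_perm is not None else torch.randperm(k_size)
    return q_w[perm, :].contiguous()

q_w = torch.arange(12).reshape(6, 2)
fixed = torch.randperm(6)
# With a fixed permutation the result is deterministic across calls.
assert torch.equal(shuffle_rows(q_w, fixed), shuffle_rows(q_w, fixed))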
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
     )
 
 
-def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
-                          group_size: int, act_order: bool):
+def gptq_quantize_weights(w: torch.Tensor,
+                          quant_type: ScalarType,
+                          group_size: int,
+                          act_order: bool,
+                          test_perm: Optional[torch.Tensor] = None):
     size_k, _ = w.shape
 
     assert w.is_floating_point(), "w must be float"
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
         ), "For act_order, groupsize = {} must be less than size_k = {}".format(
             group_size, size_k)
 
-        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
+                                                    test_perm)
 
     return w_ref, w_q, w_s, g_idx, rand_perm
 
@@ -24,10 +24,18 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = ["fp8", "compressed-tensors"]
+    # for gptq_marlin, only run fused MoE for int4
+    if model_config.quantization == "gptq_marlin":
+        hf_quant_config = getattr(model_config.hf_config,
+                                  "quantization_config", None)
+        if hf_quant_config and hf_quant_config.get("bits") == 4:
+            mixtral_supported.append("gptq_marlin")
+
     if (model_config.quantization is not None
             and model_config.quantization not in mixtral_supported
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
 
     return ModelRegistry.resolve_model_cls(architectures)
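The int4-only gate mirrors how GPTQ checkpoints expose their settings: the Hugging Face `quantization_config` carries a `bits` field, and only 4-bit models take the fused Marlin MoE path. Below is a standalone sketch of the same decision, with a hand-written dict standing in for `model_config.hf_config` (the helper name and config values are assumptions for illustration).

def use_fused_moe_for_mixtral(quantization: str, hf_config: dict) -> bool:
    # Mirrors the gate above: fp8 and compressed-tensors always qualify;
    # gptq_marlin qualifies only when the checkpoint is 4-bit.
    mixtral_supported = ["fp8", "compressed-tensors"]
    if quantization == "gptq_marlin":
        quant_cfg = hf_config.get("quantization_config")
        if quant_cfg and quant_cfg.get("bits") == 4:
            mixtral_supported.append("gptq_marlin")
    return quantization in mixtral_supported

print(use_fused_moe_for_mixtral("gptq_marlin",
                                {"quantization_config": {"bits": 4}}))  # True
print(use_fused_moe_for_mixtral("gptq_marlin",
                                {"quantization_config": {"bits": 8}}))  # False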
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if ((name.endswith(".bias") or name.endswith("_bias"))
+                        and name not in params_dict):
                     continue
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     # Skip layers on other devices.
                     if is_pp_missing_parameter(name, self):
                         continue
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                         break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if ((name.endswith(".bias") or name.endswith("_bias"))
+                            and name not in params_dict):
                         continue
                     # Skip layers on other devices.
                     if is_pp_missing_parameter(name, self):
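For clarity, the widened skip condition used in all three hunks above can be read as a single predicate: drop a checkpoint tensor whose name ends in either `.bias` or `_bias` when the model has no matching parameter, since GPTQ-style checkpoints can ship such extra bias tensors. A small sketch with hypothetical parameter names:

def should_skip_bias(name: str, params_dict: dict) -> bool:
    # Same condition as in the weight-loading loops above.
    return ((name.endswith(".bias") or name.endswith("_bias"))
            and name not in params_dict)

# Hypothetical names, for illustration only.
params = {"model.layers.0.mlp.experts.w13_qweight": object()}
print(should_skip_bias("model.layers.0.mlp.experts.w13_qweight_bias", params))  # True
print(should_skip_bias("model.layers.0.mlp.experts.w13_qweight", params))       # False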