This commit is contained in:
parent cde9183b40
commit aae74ef95c
@@ -286,8 +286,7 @@ define_gpu_extension_target(
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
-  "csrc/moe/topk_softmax_kernels.cu"
-  "csrc/moe/marlin_moe_ops.cu")
+  "csrc/moe/topk_softmax_kernels.cu")

 define_gpu_extension_target(
   _moe_C
File diff suppressed because it is too large
@@ -1,12 +0,0 @@
-#pragma once
-
-#include <torch/all.h>
-
-torch::Tensor marlin_gemm_moe(
-    const torch::Tensor& a, const torch::Tensor& b_q_weights,
-    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
-    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
-    const torch::Tensor& g_idx, const torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
-    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
-    bool replicate_input, bool apply_weights);
@@ -1,6 +1,5 @@
 #include "core/registration.h"
 #include "moe_ops.h"
-#include "marlin_moe_ops.h"

 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   // Apply topk softmax to the gating outputs.
@@ -8,14 +7,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
      "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
-  m.def(
-      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
-      "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
-      "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
-      "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
-      "bool replicate_input, bool apply_weights) -> Tensor");
-
-  m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 }

 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
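The binding removed above is what made the custom op reachable from Python as torch.ops._moe_C.marlin_gemm_moe (the fused_marlin_moe helper removed further down calls it twice). A small availability probe, written purely to illustrate how a registered op is looked up, would report False once this commit is applied:

    import torch

    def marlin_gemm_moe_available() -> bool:
        # After this commit the _moe_C extension no longer registers
        # marlin_gemm_moe, so this lookup fails and we report False.
        try:
            op = torch.ops._moe_C.marlin_gemm_moe
        except (AttributeError, RuntimeError):
            return False
        return op is not None

    print(marlin_gemm_moe_available())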
@@ -13,7 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
 compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
-compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
@@ -300,20 +300,6 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


-def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
-                           size_k: int, size_n: int,
-                           num_bits: int) -> torch.Tensor:
-    num_experts = b_q_weight.shape[0]
-    assert size_k % 16 == 0
-    output = torch.empty((num_experts, size_k // 16, size_n * 2),
-                         device=b_q_weight.device,
-                         dtype=b_q_weight.dtype)
-    for e in range(num_experts):
-        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
-                                                    size_k, size_n, num_bits)
-    return output
-
-
 def gptq_marlin_gemm(a: torch.Tensor,
                      b_q_weight: torch.Tensor,
                      b_scales: torch.Tensor,
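The removed gptq_marlin_moe_repack wrapper simply ran the existing single-matrix gptq_marlin_repack op once per expert, allocating an output of (num_experts, size_k // 16, size_n * 2) int32 words. A tiny shape sanity-check that mirrors just that allocation (pure Python, no custom ops; the example sizes are made up):

    def marlin_moe_repacked_shape(num_experts: int, size_k: int, size_n: int):
        # Mirrors the allocation in the removed gptq_marlin_moe_repack.
        assert size_k % 16 == 0
        return (num_experts, size_k // 16, size_n * 2)

    print(marlin_moe_repacked_shape(num_experts=8, size_k=4096, size_n=14336))
    # -> (8, 256, 28672)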
@@ -1,17 +1,19 @@
-from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.triton_utils import HAS_TRITON

-__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
+__all__ = [
+    "FusedMoE",
+    "FusedMoEMethodBase",
+]

 if HAS_TRITON:

     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_marlin_moe, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)

     __all__ += [
-        "fused_marlin_moe",
         "fused_moe",
         "fused_topk",
         "fused_experts",
@@ -323,16 +323,21 @@ def get_moe_configs(E: int, N: int,
     return None


-def get_default_config(M: int, E: int, N: int, K: int, topk: int,
-                       dtype: Optional[str],
-                       is_marlin: bool) -> Dict[str, int]:
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
         'BLOCK_SIZE_K': 32,
         'GROUP_SIZE_M': 8
     }
-    if M <= E or (is_marlin and M <= 32):
+    if M <= E:
         config = {
             'BLOCK_SIZE_M': 16,
             'BLOCK_SIZE_N': 32,
@@ -342,14 +347,14 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int,
     return config


-def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
-                               w2_shape: Tuple[int, ...],
-                               top_k: int,
-                               dtype: Optional[str],
-                               M: int,
-                               override_config: Optional[Dict[str,
-                                                              Any]] = None,
-                               is_marlin: bool = False):
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
     if override_config:
         config = override_config
     else:
@@ -363,8 +368,7 @@ def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Else use the default config
-        config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
-                                    is_marlin)
+        config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
     return config

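With is_marlin gone, the default kernel config depends only on the batch size M and expert count E. A standalone sketch of the selection logic as it reads after this change; the two values in the small-M branch that the hunk truncates (BLOCK_SIZE_K, GROUP_SIZE_M) are assumptions added for illustration, not taken from the source:

    from typing import Dict, Optional

    def get_default_config_sketch(M: int, E: int, N: int, K: int, topk: int,
                                  dtype: Optional[str]) -> Dict[str, int]:
        config = {
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'GROUP_SIZE_M': 8,
        }
        if M <= E:
            # Small-batch branch; the diff above is cut off here, so the last
            # two values are assumed for the sketch.
            config = {
                'BLOCK_SIZE_M': 16,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 64,  # assumed
                'GROUP_SIZE_M': 1,   # assumed
            }
        return config

    print(get_default_config_sketch(M=4, E=8, N=14336, K=4096, topk=2, dtype=None))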
@@ -437,108 +441,6 @@ def grouped_topk(hidden_states: torch.Tensor,
     return topk_weights, topk_ids


-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     renormalize: bool,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
-    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w1.
-    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
-        w2.
-    Returns:
-    - torch.Tensor: The output tensor after applying the MoE layer.
-    """
-    # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[
-        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
-    assert hidden_states.shape[
-        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-
-    #TODO fp8 is not implemented yet
-    assert not use_fp8
-
-    M, K = hidden_states.shape
-    E = w1.shape[0]
-    N = w2.shape[1] * 16
-
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
-
-    get_config_func = functools.partial(try_get_optimal_moe_config,
-                                        w1.shape,
-                                        w2.shape,
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None,
-                                        override_config=override_config,
-                                        is_marlin=True)
-    config = get_config_func(M)
-
-    block_size_m = config['BLOCK_SIZE_M']
-
-    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
-
-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
-    workspace = torch.zeros(max_workspace_size,
-                            dtype=torch.int,
-                            device="cuda",
-                            requires_grad=False)
-
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-
-    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
-        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
-        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
-        block_size_m, True, False)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
-
-    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
-        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
-        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
-        block_size_m, False, True)
-
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
-
-
 def get_config_dtype_str(dtype: torch.dtype,
                          use_int8_w8a16: Optional[bool] = False,
                          use_fp8_w8a8: Optional[bool] = False):

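The deleted fused_marlin_moe routed tokens with fused_topk, ran one marlin_gemm_moe call for the packed gate/up projection, applied silu_and_mul, ran a second marlin_gemm_moe for the down projection, and summed over the selected experts. For orientation only, here is an unfused plain-PyTorch reference of that same dataflow; it sketches the math, not the Marlin kernel, and the dense weight shapes (E, 2N, K) and (E, K, N) are assumptions chosen for readability:

    import torch
    import torch.nn.functional as F

    def moe_reference(hidden_states, w1, w2, gating_output, topk, renormalize=True):
        # hidden_states: (M, K); w1: (E, 2N, K); w2: (E, K, N); gating_output: (M, E)
        M, K = hidden_states.shape
        E, two_n, _ = w1.shape
        N = two_n // 2

        topk_weights, topk_ids = torch.topk(
            F.softmax(gating_output, dim=-1), topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        out = torch.zeros_like(hidden_states)
        for e in range(E):
            token_idx, slot = torch.where(topk_ids == e)
            if token_idx.numel() == 0:
                continue
            x = hidden_states[token_idx]                   # (t, K)
            gate_up = x @ w1[e].t()                        # (t, 2N)
            act = F.silu(gate_up[:, :N]) * gate_up[:, N:]  # silu_and_mul
            down = act @ w2[e].t()                         # (t, K)
            out.index_add_(0, token_idx,
                           down * topk_weights[token_idx, slot, None])
        return out

    x = torch.randn(4, 64)
    w1 = torch.randn(8, 64, 64)   # E=8, 2N=64, K=64
    w2 = torch.randn(8, 64, 32)   # E=8, K=64, N=32
    gates = torch.randn(4, 8)
    print(moe_reference(x, w1, w2, gates, topk=2).shape)  # torch.Size([4, 64])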
@@ -1,5 +1,4 @@
 from abc import abstractmethod
-from enum import Enum
 from typing import List, Optional, Tuple

 import torch
@@ -16,12 +15,6 @@ from vllm.model_executor.utils import set_weight_attrs
 logger = init_logger(__name__)


-class FusedMoeWeightScaleSupported(Enum):
-    TENSOR = "tensor"
-    CHANNEL = "channel"
-    GROUP = "group"
-
-
 class FusedMoEMethodBase(QuantizeMethodBase):

     @abstractmethod
@@ -206,182 +199,55 @@ class FusedMoE(torch.nn.Module):
                                    params_dtype=params_dtype,
                                    weight_loader=self.weight_loader)

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.tensor,
                                                 tp_rank: int):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(shard_id=shard_id,
                          shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.tensor,
                                       tp_rank: int):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.tensor, tp_rank: int):

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
                 shard_id: str, loaded_weight: torch.tensor, tp_rank: int):

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
        # Special case for fp8 scales.
        if getattr(param, "is_fp8_scale", False):
            self._load_fp8_scale(param.data, loaded_weight, weight_name,
                                 shard_id, expert_id)
            return

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: whether or not the parameter is transposed on disk
        # If transposed, the loaded weight will be transposed and the dim
        # to shard the loaded weight will be flipped.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            loaded_weight = loaded_weight.t().contiguous()
            shard_dim = ~shard_dim
        # If transposed, weight is saved as [input_dim, output_dim]
        # Otherwise, weight is saved as [output_dim, input_dim]
        # Default is not transposed/input dim is dim 1
        input_dim = getattr(param, "input_dim", 1)
        output_dim = getattr(param, "output_dim", 0)

        # Case weight_scales
        if "weight_scale" in weight_name:
            # load the weight scaling based on the quantization scheme
            # supported weight scales can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank)
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        if shard_id == "w2":
            shard_dim = input_dim
            shard_size = expert_data.shape[shard_dim]
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        elif shard_id in ("w1", "w3"):
            shard_dim = output_dim
            shard_size = expert_data.shape[output_dim] // 2
        offset = shard_size * tp_rank
        loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)

        if "weight_shape" in weight_name:
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case input scale
        if "input_scale" in weight_name:
            # Note: input_scale loading is only supported for fp8
            if param.data[expert_id] != 1 and (param.data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank)
            return
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
            expert_data.copy_(loaded_weight)
        # w3, up_proj: Load into second logical weight of w13.
        elif shard_id == "w3":
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
            expert_data.copy_(loaded_weight)
        # w2, down_proj: Load into only logical weight of w2.
        elif shard_id == "w2":
            expert_data.copy_(loaded_weight)
        else:
            raise ValueError(
                f"Expected shard_id w1,w2 or w3 but got {shard_id}")

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
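Both the deleted helpers and the restored loader shard expert weights with Tensor.narrow along the intermediate dimension, with w1 (gate_proj) and w3 (up_proj) packed into the two halves of a single w13 parameter. A self-contained illustration of that slicing with toy sizes (intermediate_size=8, tp_size=2, hidden_size=4, all invented for the example):

    import torch

    def shard_w13(loaded_weight, expert_w13, shard_id, shard_dim, tp_rank):
        # expert_w13 stacks gate_proj (w1) and up_proj (w3) along shard_dim,
        # so each logical weight occupies half of that dimension.
        shard_size = expert_w13.shape[shard_dim] // 2
        # Take this rank's slice of the checkpoint tensor (MergedColumnParallel).
        loaded = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
        if shard_id == "w1":
            dst = expert_w13.narrow(shard_dim, 0, shard_size)
        else:  # "w3"
            dst = expert_w13.narrow(shard_dim, shard_size, shard_size)
        dst.copy_(loaded)

    expert_w13 = torch.zeros(8, 4)               # (2 * intermediate // tp, hidden)
    full_gate = torch.arange(32.).reshape(8, 4)  # full gate_proj from the checkpoint
    shard_w13(full_gate, expert_w13, shard_id="w1", shard_dim=0, tp_rank=1)
    print(expert_w13[:4])                        # rows 4..7 of gate_proj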
@@ -3,12 +3,9 @@ from typing import Any, Dict, List, Optional
 import torch
 from pydantic import BaseModel

-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
-    CompressedTensorsMoEMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsUnquantized,
@@ -67,8 +64,6 @@ class CompressedTensorsConfig(QuantizationConfig):
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
-        if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod(self)
         return None

     @classmethod
@@ -1,283 +0,0 @@
import enum
from enum import Enum
from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


__all__ = ["CompressedTensorsMoEMethod"]


class CompressedTensorsMoEMethod(FusedMoEMethodBase):

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy.value
        self.group_size = config.group_size
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            raise ValueError("For Fused MoE layers, only ",
                             f"{CompressionFormat.pack_quantized.value} ",
                             "is supported for the following bits: ",
                             f"{WNA16_SUPPORTED_BITS}")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    hidden_size //
                                                    self.packed_factor,
                                                    2 * intermediate_size,
                                                    dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   intermediate_size //
                                                   self.packed_factor,
                                                   hidden_size,
                                                   dtype=torch.int32),
                                       requires_grad=False)
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = intermediate_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  num_groups_w13,
                                                  2 * intermediate_size,
                                                  dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 num_groups_w2,
                                                 hidden_size,
                                                 dtype=params_dtype),
                                      requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)

        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)

        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)

        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)

        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)

        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        layer.a13_scale = None
        layer.a2_scale = None
        layer.marlin_state = GPTQMarlinState.REPACK

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        def get_scale_perms(num_bits: int):
            scale_perm: List[int] = []
            for i in range(8):
                scale_perm.extend([i + 8 * j for j in range(8)])
            scale_perm_single: List[int] = []
            for i in range(4):
                scale_perm_single.extend(
                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
            return scale_perm, scale_perm_single

        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                                  group_size: int, num_bits: int):
            scale_perm, scale_perm_single = get_scale_perms(num_bits)
            if group_size < size_k and group_size != -1:
                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
            else:
                s = s.reshape((-1, len(scale_perm_single)))[:,
                                                            scale_perm_single]
            s = s.reshape((-1, size_n)).contiguous()
            return s

        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
                                      size_n: int, group_size: int,
                                      num_bits: int):
            num_experts = s.shape[0]
            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
                                 device=s.device,
                                 dtype=s.dtype)
            for e in range(num_experts):
                output[e] = marlin_permute_scales(s[e], size_k, size_n,
                                                  group_size, num_bits)
            return output

        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_g_idx.shape[0]
        device = layer.w13_g_idx.device
        layer.w13_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1] * self.packed_factor,
            size_k2,
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:

        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_marlin_moe)

        return fused_marlin_moe(x,
                                layer.w13_weight_packed,
                                layer.w2_weight_packed,
                                router_logits,
                                layer.w13_g_idx,
                                layer.w2_g_idx,
                                layer.w13_g_idx_sort_indices,
                                layer.w2_g_idx_sort_indices,
                                top_k,
                                renormalize=renormalize,
                                w1_scale=layer.w13_weight_scale,
                                w2_scale=layer.w2_weight_scale)
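The deleted process_weights_after_loading repacks the GPTQ buffers in place through a small replace_tensor closure: resize_() followed by copy_() keeps the already-registered parameter object (and every reference to it) pointing at the reused storage. A standalone sketch of just that pattern, with a toy module invented for the example:

    import torch

    class ToyLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)

    def replace_tensor(layer, name, new_t):
        # Same idea as the deleted helper: resize the existing buffer and copy
        # the repacked data into it instead of rebinding the attribute.
        getattr(layer, name).resize_(new_t.shape)
        getattr(layer, name).copy_(new_t)

    layer = ToyLayer()
    repacked = torch.arange(8.).reshape(2, 4)
    replace_tensor(layer, "w", repacked)
    print(layer.w.shape)  # torch.Size([2, 4])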
@@ -7,8 +7,7 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -319,16 +318,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                       dtype=torch.float32),
                                           requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
         #   process_weights_after_loading()
         if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w13_weight_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
+            set_weight_attrs(w2_weight_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })

         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
@@ -341,14 +343,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_experts, dtype=torch.float32),
                                                 requires_grad=False)
             layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+            set_weight_attrs(w13_input_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })

             w2_input_scale = torch.nn.Parameter(torch.ones(
                 num_experts, dtype=torch.float32),
                                                 requires_grad=False)
             layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+            set_weight_attrs(w2_input_scale, {
+                "is_fp8_scale": True,
+                **extra_weight_attrs
+            })
         else:
             layer.w13_input_scale = None
             layer.w2_input_scale = None
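After this change, fp8 scale parameters are tagged with an is_fp8_scale attribute (merged with the other extra weight attrs) so the restored FusedMoE.weight_loader can special-case them. A minimal sketch of what that tagging amounts to; the helper below is a stand-in for vllm.model_executor.utils.set_weight_attrs, which simply attaches the key/value pairs to the parameter object:

    import torch

    def set_weight_attrs_sketch(weight, attrs):
        # Stand-in for set_weight_attrs: attach each attribute to the tensor
        # so the weight loader can inspect it later.
        for key, value in attrs.items():
            setattr(weight, key, value)

    extra_weight_attrs = {"weight_loader": print}  # placeholder loader
    w13_weight_scale = torch.nn.Parameter(torch.ones(8), requires_grad=False)
    set_weight_attrs_sketch(w13_weight_scale, {
        "is_fp8_scale": True,
        **extra_weight_attrs,
    })
    print(getattr(w13_weight_scale, "is_fp8_scale", False))  # True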
@@ -23,11 +23,11 @@ def get_model_architecture(
     architectures = getattr(model_config.hf_config, "architectures", [])
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
-    mixtral_supported = ["fp8", "compressed-tensors"]
     if (model_config.quantization is not None
-            and model_config.quantization not in mixtral_supported
+            and model_config.quantization != "fp8"
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]

     return ModelRegistry.resolve_model_cls(architectures)

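As read from the hunk above, the restored check routes MixtralForCausalLM to the quantized implementation for any quantization method other than fp8. A condensed, self-contained version of that branch; the config object is a stand-in invented for the example:

    class StubModelConfig:
        quantization = "awq"

    def resolve_mixtral_architectures(model_config, architectures):
        # Mirrors the restored special case: any non-fp8 quantization method
        # falls back to QuantMixtralForCausalLM.
        if (model_config.quantization is not None
                and model_config.quantization != "fp8"
                and "MixtralForCausalLM" in architectures):
            architectures = ["QuantMixtralForCausalLM"]
        return architectures

    print(resolve_mixtral_architectures(StubModelConfig(), ["MixtralForCausalLM"]))
    # ['QuantMixtralForCausalLM']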
@@ -920,7 +920,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                 weight_loader = param.weight_loader
                 weight_loader(param,
                               loaded_weight,
-                              name,
+                              weight_name,
                               shard_id=shard_id,
                               expert_id=expert_id)
                 break
@@ -73,7 +73,6 @@ class MixtralMoE(nn.Module):
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.

        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,