[ Misc ] Refactor MoE to isolate Fp8 From Mixtral (#5970)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent 4d26d806e1
commit 7c008c51a9
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.86
  - name: "exact_match,flexible-extract"
    value: 0.86
limit: 250
num_fewshot: 5
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.624
  - name: "exact_match,flexible-extract"
    value: 0.624
limit: 250
num_fewshot: 5
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.792
  - name: "exact_match,flexible-extract"
    value: 0.824
limit: 250
num_fewshot: 5
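These configs pair a model with the GSM8K accuracy it is expected to reproduce when served by vLLM. The harness script that consumes them is not part of this diff; a minimal sketch of how the fields could be checked, assuming lm-eval results are already in hand (the results dict layout and the check_baseline helper are illustrative, not the real buildkite script):

import yaml

# Hypothetical helper: compare measured lm-eval metrics against one of the
# YAML baselines above. `results` mimics lm-eval's per-task metric summary.
def check_baseline(config_path: str, results: dict, tol: float = 0.05) -> None:
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        for metric in task["metrics"]:
            measured = results[task["name"]][metric["name"]]
            expected = metric["value"]
            assert abs(measured - expected) <= tol, (
                f"{task['name']}/{metric['name']}: got {measured}, "
                f"expected ~{expected}")

# Example: check_baseline("Mixtral-8x7B-Instruct-v0.1-FP8.yaml",
#                         {"gsm8k": {"exact_match,strict-match": 0.63,
#                                    "exact_match,flexible-extract": 0.63}})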
@@ -1,2 +1,3 @@
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
+Qwen2-57B-A14-Instruct.yaml
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
@@ -1,5 +1,7 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)

 __all__ = [
     "fused_moe",
@@ -7,4 +9,6 @@ __all__ = [
     "fused_experts",
     "get_config_file_name",
     "grouped_topk",
+    "FusedMoE",
+    "FusedMoEMethodBase",
 ]
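With these re-exports, model code can pull everything it needs from the package root instead of reaching into the kernel module; a small sketch of the new import surface (paths taken from this diff):

# Models now depend on the FusedMoE layer; the raw fused_moe kernel entry
# point remains exported for quant methods and tests.
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)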
vllm/model_executor/layers/fused_moe/layer.py (new file, 197 lines)
@@ -0,0 +1,197 @@
from abc import abstractmethod
from typing import Optional

import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    """MoE method without quantization."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True) -> torch.Tensor:

        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True)


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all-reduce the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: int, expert_id: int):
        param_data = param.data

        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
        # Follow up PR to enable fp8 for other MoE models.
        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
        # Follow up PR to enable fp8 for other MoE models.
        elif "weight_scale" in weight_name:
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            assert "w1" in weight_name or "w3" in weight_name
            shard_id = 0 if "w1" in weight_name else 1
            param_data[expert_id][shard_id] = loaded_weight
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.intermediate_size_per_partition
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == 0:
                param_data[expert_id,
                           0:shard_size, :] = loaded_weight[shard, :]
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == 2:
                param_data[expert_id, shard_size:2 *
                           shard_size, :] = loaded_weight[shard, :]
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == 1:
                param_data[expert_id, :, :] = loaded_weight[:, shard]
            else:
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states
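For orientation, a model block is expected to wire this layer up as follows: a replicated gate produces router logits, and FusedMoE owns the expert weights and the quant-method dispatch. A minimal sketch mirroring the MixtralMoE rewrite later in this diff (the ToyMoEBlock class and its argument values are illustrative, not part of vLLM):

import torch
from torch import nn

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear


class ToyMoEBlock(nn.Module):
    """Illustrative MoE block: a replicated router plus fused experts."""

    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int, quant_config=None):
        super().__init__()
        # The router stays unquantized; it emits one logit per expert.
        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False,
                                     quant_config=None)
        # FusedMoE allocates the expert weights via the quant method
        # (UnquantizedFusedMoEMethod or Fp8MoEMethod, chosen by quant_config).
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_size)
        router_logits, _ = self.gate(hidden_states)
        out = self.experts(hidden_states, router_logits)
        return out.view(num_tokens, hidden_size)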
@@ -6,6 +6,8 @@ from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  fused_moe)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -71,7 +73,9 @@ class Fp8Config(QuantizationConfig):

         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
-        if isinstance(layer, Attention):
+        elif isinstance(layer, FusedMoE):
+            return Fp8MoEMethod(self)
+        elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
         return None

@@ -270,6 +274,187 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])


+class Fp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    dynamic/static activation scale.
+
+    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
+                       intermediate_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+
+        layer.process_after_load = True
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                  2,
+                                                  dtype=torch.float32),
+                                       requires_grad=False)
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                 dtype=torch.float32),
+                                      requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        # If loading an fp8 checkpoint, pass the weight loaders.
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()).
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8.")
+
+            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                      dtype=torch.float32),
+                                           requires_grad=False)
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if (not hasattr(layer, "process_after_load")
+                or not layer.process_after_load):
+            return
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(layer.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(layer.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(torch.ones(
+                layer.num_experts,
+                dtype=torch.float32,
+                device=w13_weight.device),
+                                                 requires_grad=False)
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], layer.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w2_weight.data[expert, :, :])
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            return
+
+        # If checkpoint is fp8, we need to handle that the
+        # MoE kernels require a single activation scale and a single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
+                if (not all_close_1d(layer.a13_scale)
+                        or not all_close_1d(layer.a2_scale)):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer. ")
+                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
+                                                     requires_grad=False)
+                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
+                                                    requires_grad=False)
+
+            # Fp8 moe kernel needs a single weight scale for w13 per expert.
+            # We take the max, then dequant and requant each expert.
+            assert layer.w13_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_scale.max(dim=1).values
+            for expert_id in range(layer.num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :] = per_tensor_quantize(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+
+            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
+                                                 requires_grad=False)
+            return
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True,
+                         use_fp8=True,
+                         w1_scale=layer.w13_scale,
+                         w2_scale=layer.w2_scale,
+                         a1_scale=layer.a13_scale,
+                         a2_scale=layer.a2_scale)
+
+
 class Fp8KVCacheMethod(QuantizeMethodBase):
     """Supports loading kv-cache scaling factors from FP8 checkpoints.
     """
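The dequantize-then-requantize step above is easier to see with scalars. A toy sketch with made-up numbers (not the real per_tensor_quantize / per_tensor_dequantize helpers): each w13 shard is mapped back to real values with the scale it was saved with, then re-encoded against the per-expert maximum scale, so the fused kernel only needs one weight scale per expert.

import torch

# Fake quantized w1 shard for one expert, the scale it was stored with, and
# the max scale across that expert's w1/w3 shards (illustrative values only).
q_shard = torch.tensor([[64.0, -32.0], [16.0, 8.0]])
scale_shard = torch.tensor(0.02)
scale_max = torch.tensor(0.05)

dequant = q_shard * scale_shard                       # back to "real" weights
requant = (dequant / scale_max).clamp(-448.0, 448.0)  # re-encode against the
                                                      # shared scale; 448 is
                                                      # the float8_e4m3fn max
print(requant)  # same real values, now expressed in the per-expert max scale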
@@ -321,3 +506,8 @@ def per_tensor_dequantize(
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
@@ -27,13 +27,10 @@ import torch
 from torch import nn
 from transformers import MixtralConfig

-from vllm import _custom_ops as ops
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
@@ -41,16 +38,12 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
-                                                         per_tensor_dequantize,
-                                                         per_tensor_quantize)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import print_warning_once

@@ -66,227 +59,40 @@ class MixtralMoE(nn.Module):
     across ranks.
     """

-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 tp_size: Optional[int] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype

         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
+        self.gate = ReplicatedLinear(hidden_size,
+                                     num_experts,
                                      bias=False,
-                                     params_dtype=self.params_dtype,
+                                     params_dtype=params_dtype,
                                      quant_config=None)

-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                   2 * self.intermediate_size,
-                                                   self.hidden_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
-        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
-                                                  self.hidden_size,
-                                                  self.intermediate_size,
-                                                  dtype=params_dtype),
-                                      requires_grad=False)
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
-
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They will be combined to a single scale after weight loading.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     2,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            #   process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-            # INPUT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.ones(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-
-        # Loading scales
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        else:
-            # If checkpoint is fp8 + static, cleanup input_scales.
-            # Since state_dict has an input_scale per expert but our kernels
-            # are passed one input_scale shared across all experts.
-            if self.quant_config.activation_scheme == "static":
-                if self.a13_scale is None or self.a2_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None.")
-
-                if (not all_close_1d(self.a13_scale)
-                        or not all_close_1d(self.a2_scale)):
-                    print_warning_once(
-                        "Found input_scales that are not equal for "
-                        "fp8 MoE layer. Using the maximum across experts "
-                        "for each layer. ")
-
-                self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                             requires_grad=False)
-
-            assert self.w13_scale is not None
-            shard_size = self.intermediate_size
-            max_w13_scales = self.w13_scale.max(dim=1).values
-            for expert_id in range(self.num_total_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        self.w13_weight[expert_id][start:start +
-                                                   shard_size, :],
-                        self.w13_scale[expert_id][shard_id])
-                    self.w13_weight[expert_id][
-                        start:start + shard_size, :] = per_tensor_quantize(
-                            dq_weight, max_w13_scales[expert_id])
-                    start += shard_size
-
-            self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
+        self.experts = FusedMoE(num_experts=num_experts,
+                                top_k=top_k,
+                                hidden_size=hidden_size,
+                                intermediate_size=intermediate_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=True,
+                                quant_config=quant_config,
+                                tp_size=tp_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=True,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
+        final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(num_tokens, hidden_size)

@@ -566,25 +372,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):

         expert_params_mapping = [
             # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in ["w1", "w3"] else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            ("experts.w13_weight"
+             if weight_name in ["w1", "w3"] else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
             for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
-            for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
+            ("experts.a13_scale"
+             if weight_name in ["w1", "w3"] else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(self.config.num_local_experts)
+            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
         ]

         params_dict = dict(self.named_parameters())
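To make the remapping concrete, the comprehension above expands into tuples that pair each per-expert checkpoint tensor with the fused parameter it lands in, plus the expert and shard indices the FusedMoE weight loader expects. A standalone sketch with a made-up expert count:

# Illustrative expansion of the weight mapping for two local experts
# (real Mixtral configs have more).
num_local_experts = 2
expert_params_mapping = [
    ("experts.w13_weight"
     if weight_name in ["w1", "w3"] else "experts.w2_weight",
     f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
    for expert_id in range(num_local_experts)
    for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
]
# First entries: (fused param, checkpoint name, expert_id, shard_id)
# ('experts.w13_weight', 'experts.0.w1.weight', 0, 0)
# ('experts.w2_weight',  'experts.0.w2.weight', 0, 1)
# ('experts.w13_weight', 'experts.0.w3.weight', 0, 2)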
@@ -604,7 +413,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -613,6 +423,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
@@ -637,8 +448,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
@@ -31,11 +31,10 @@ from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -93,28 +92,23 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.n_routed_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.n_routed_experts:
+
+        if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.n_routed_experts}.")
+                f"the number of experts {config.num_experts}.")

-        self.experts = nn.ModuleList([
-            Qwen2MoeMLP(hidden_size=config.hidden_size,
-                        intermediate_size=config.moe_intermediate_size,
-                        hidden_act=config.hidden_act,
-                        quant_config=quant_config,
-                        reduce_results=False)
-            for idx in range(self.n_routed_experts)
-        ])
-        self.pack_params()
+        self.experts = FusedMoE(num_experts=config.num_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config)

         self.gate = ReplicatedLinear(config.hidden_size,
-                                     self.n_routed_experts,
+                                     config.num_experts,
                                      bias=False,
                                      quant_config=None)
         if config.shared_expert_intermediate_size > 0:
@@ -131,25 +125,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                                                 1,
                                                 bias=False)

-    def pack_params(self):
-        w1 = []
-        w2 = []
-        for expert in self.experts:
-            w1.append(expert.gate_up_proj.weight)
-            w2.append(expert.down_proj.weight)
-        self.w1 = torch._utils._flatten_dense_tensors(w1)
-        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
-        for data, param in zip(w1s, w1):
-            param.data = data
-        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
-
-        self.w2 = torch._utils._flatten_dense_tensors(w2)
-        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
-        for data, param in zip(w2s, w2):
-            param.data = data
-
-        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -162,16 +137,11 @@ class Qwen2MoeSparseMoeBlock(nn.Module):

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w1,
-                                        self.w2,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=self.config.norm_topk_prob,
-                                        inplace=True)
-
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

@@ -284,7 +254,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
         )
-        if (layer_idx not in config.mlp_only_layers) and (
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
+                           config.mlp_only_layers)
+        if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
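The hasattr guard exists because older Qwen2 MoE configs (e.g. Qwen/Qwen2-57B-A14B-Instruct) ship without `mlp_only_layers`. A self-contained sketch of the equivalent fallback, using a stand-in config object rather than a real HF config:

from types import SimpleNamespace

# Stand-in for a config that predates `mlp_only_layers`.
config = SimpleNamespace(num_experts=64, decoder_sparse_step=1)

# Same effect as the hasattr guard above: default to "no MLP-only layers".
mlp_only_layers = getattr(config, "mlp_only_layers", [])
print(mlp_only_layers)  # []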
@@ -427,21 +402,36 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

+        expert_params_mapping = [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
+             else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(self.config.num_experts) for shard_id,
+            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
+        ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
                     continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # Skip experts that are not assigned to this worker.
-                if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                        and name not in params_dict):
-                    continue
                 if name not in params_dict:
                     continue
@@ -449,14 +439,24 @@ class Qwen2MoeForCausalLM(nn.Module):
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    # Skip experts that are not assigned to this worker.
-                    if (("mlp.experts." in name or "mlp.shared_expert." in name)
-                            and name not in params_dict):
-                        continue
                     if name not in params_dict:
                         continue

|
Loading…
x
Reference in New Issue
Block a user