[ Misc ] Refactor MoE to isolate Fp8 From Mixtral (#5970)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Robert Shaw 2024-07-02 17:54:35 -04:00 committed by GitHub
parent 4d26d806e1
commit 7c008c51a9
10 changed files with 535 additions and 304 deletions

View File: .buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic.yaml (new file)

@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8
model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.86
- name: "exact_match,flexible-extract"
value: 0.86
limit: 250
num_fewshot: 5
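Configs like the one above pin expected GSM8K accuracy for the new FP8 MoE path. For context, a minimal sketch of how such a config could be checked against harness results; the helper below is illustrative only (the function, the measured dict, and the tolerance are assumptions, not part of this commit):

# Hypothetical checker for configs like the one above (not part of this
# commit). `measured` maps task name -> metric name -> observed accuracy.
import yaml

RTOL = 0.05  # assumed relative tolerance

def check_baseline(config_path: str, measured: dict) -> None:
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            got = measured[task["name"]][metric["name"]]
            want = metric["value"]
            assert abs(got - want) <= RTOL * want, (
                f'{task["name"]}/{metric["name"]}: got {got}, expected ~{want}')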

View File: .buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml (new file)

@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4
model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.624
- name: "exact_match,flexible-extract"
value: 0.624
limit: 250
num_fewshot: 5

View File: .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml (new file)

@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4
model_name: "Qwen/Qwen2-57B-A14B-Instruct"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.792
- name: "exact_match,flexible-extract"
value: 0.824
limit: 250
num_fewshot: 5

View File: .buildkite/lm-eval-harness/configs/models-large.txt

@@ -1,2 +1,3 @@
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml

View File: tests/kernels/test_moe.py

@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")

View File: vllm/model_executor/layers/fused_moe/__init__.py

@@ -1,5 +1,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
__all__ = [
"fused_moe",
@@ -7,4 +9,6 @@ __all__ = [
"fused_experts",
"get_config_file_name",
"grouped_topk",
"FusedMoE",
"FusedMoEMethodBase",
]

View File: vllm/model_executor/layers/fused_moe/layer.py (new file)

@@ -0,0 +1,197 @@
from abc import abstractmethod
from typing import Optional
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
"""MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True)
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallelLinear weights (gate_up_proj /
w13) and RowParallelLinear weights (down_proj / w2).
Note: Mixtral uses w1 for gate_proj, w3 for up_proj, and w2 for down_proj.
We copy that naming convention here and handle any remapping in the
load_weights function in each model implementation.
Args:
num_experts: Number of experts in the model
top_k: Number of experts selected for each token
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all-reduce the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization config.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
self.top_k = top_k
self.num_experts = num_experts
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
param_data = param.data
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states
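The new FusedMoE layer owns the expert weights and the fused kernel dispatch, while the router (gate) stays in the model code, as in the model changes later in this diff. A minimal single-GPU usage sketch, assuming a CUDA device, fp16, and tp_size=1 so no distributed group is required (shapes are illustrative, not prescribed by this commit):

import torch
from vllm.model_executor.layers.fused_moe import FusedMoE

hidden_size, intermediate_size, num_experts, top_k = 4096, 14336, 8, 2
moe = FusedMoE(num_experts=num_experts,
               top_k=top_k,
               hidden_size=hidden_size,
               intermediate_size=intermediate_size,
               params_dtype=torch.float16,
               reduce_results=False,  # nothing to all-reduce when tp_size=1
               renormalize=True,
               quant_config=None,     # selects UnquantizedFusedMoEMethod
               tp_size=1).cuda()

# Weights are uninitialized here; a real model fills them via weight_loader.
gate = torch.nn.Linear(hidden_size, num_experts, bias=False,
                       dtype=torch.float16, device="cuda")
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
out = moe(x, gate(x))  # -> (num_tokens, hidden_size)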

View File: vllm/model_executor/layers/quantization/fp8.py

@@ -6,6 +6,8 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
fused_moe)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -71,7 +73,9 @@ class Fp8Config(QuantizationConfig):
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
if isinstance(layer, Attention):
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@@ -270,6 +274,187 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.narrow(output, 0, 0, x.shape[0])
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports quantizing FP16/BF16 model checkpoints on load, with dynamic
activation scaling; in that case the weight scaling factors are initialized
after the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
layer.process_after_load = True
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
# If loading an fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()).
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
a13_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a13_scale", a13_scale)
set_weight_attrs(a13_scale, extra_weight_attrs)
a2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a2_scale", a2_scale)
set_weight_attrs(a2_scale, extra_weight_attrs)
else:
layer.a13_scale = None
layer.a2_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(layer.w2_weight.data,
dtype=torch.float8_e4m3fn)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_scale = torch.nn.Parameter(torch.ones(
layer.num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
return
# If the checkpoint is fp8, handle the fact that the MoE
# kernels require a single activation scale and a single
# weight scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.a13_scale is None or layer.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.a13_scale)
or not all_close_1d(layer.a2_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
requires_grad=False)
layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
requires_grad=False)
# The fp8 MoE kernel needs a single weight scale for w13 per expert,
# so we take the max, then dequantize and requantize each shard.
assert layer.w13_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :] = per_tensor_quantize(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale)
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
@@ -321,3 +506,8 @@ def per_tensor_dequantize(
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
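The w13 requantization above (take the max of the w1/w3 scales, then dequantize and requantize each shard) can be illustrated standalone. A sketch with plain-torch stand-ins for per_tensor_quantize/per_tensor_dequantize; the helper names and shapes below are local to this example:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Mirrors per_tensor_dequantize above: upcast, then apply the scale.
    return q.to(torch.float16) * scale

w1 = torch.randn(32, 64, dtype=torch.float16)
w3 = torch.randn(32, 64, dtype=torch.float16) * 3  # larger dynamic range
s1 = w1.abs().max().float() / FP8_MAX
s3 = w3.abs().max().float() / FP8_MAX
q1, q3 = quantize(w1, s1), quantize(w3, s3)

# The fused kernel takes ONE w13 scale per expert, so requantize both
# shards with the max scale, as process_weights_after_loading does.
s13 = torch.max(s1, s3)
q1 = quantize(dequantize(q1, s1), s13)
q3 = quantize(dequantize(q3, s3), s13)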

View File: vllm/model_executor/models/mixtral.py

@@ -27,13 +27,10 @@ import torch
from torch import nn
from transformers import MixtralConfig
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
@@ -41,16 +38,12 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
per_tensor_dequantize,
per_tensor_quantize)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
@@ -66,227 +59,40 @@ class MixtralMoE(nn.Module):
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
self.gate = ReplicatedLinear(hidden_size,
num_experts,
bias=False,
params_dtype=self.params_dtype,
params_dtype=params_dtype,
quant_config=None)
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype),
requires_grad=False)
self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
2,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
# If loading an fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()).
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
# INPUT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.ones(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
# Loading scales
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
else:
# If checkpoint is fp8 + static, cleanup input_scales.
# Since state_dict has an input_scale per expert but our kernels
# are passed one input_scale shared across all experts.
if self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
assert self.w13_scale is not None
shard_size = self.intermediate_size
max_w13_scales = self.w13_scale.max(dim=1).values
for expert_id in range(self.num_total_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
self.w13_weight[expert_id][start:start +
shard_size, :],
self.w13_scale[expert_id][shard_id])
self.w13_weight[expert_id][
start:start + shard_size, :] = per_tensor_quantize(
dq_weight, max_w13_scales[expert_id])
start += shard_size
self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(num_tokens, hidden_size)
@@ -566,25 +372,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_scale"
if weight_name in ["w1", "w3"] else "experts.w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
shard_id) for expert_id in range(self.config.num_local_experts)
for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
("experts.w13_weight"
if weight_name in ["w1", "w3"] else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
("experts.a13_scale"
if weight_name in ["w1", "w3"] else "experts.a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
shard_id) for expert_id in range(self.config.num_local_experts)
for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
]
params_dict = dict(self.named_parameters())
@@ -604,7 +413,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -613,6 +423,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader(param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
@@ -637,8 +448,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
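The expanded 4-tuple expert_params_mapping above feeds the new FusedMoE.weight_loader. An illustrative walk-through of how one Mixtral checkpoint tensor reaches the fused parameter (string handling only, not part of the commit):

# One entry generated for expert 3, weight w1 (shard_id 0):
param_name = "experts.w13_weight"    # fused destination parameter
weight_name = "experts.3.w1.weight"  # substring matched in the ckpt name
expert_id, shard_id = 3, 0           # w1 -> first half of w13

name = "model.layers.0.block_sparse_moe.experts.3.w1.weight"
assert weight_name in name
name = name.replace(weight_name, param_name)
# -> "model.layers.0.block_sparse_moe.experts.w13_weight"
# weight_loader(param, loaded_weight, weight_name, shard_id=0, expert_id=3)
# then copies loaded_weight into w13_weight[3, 0:shard_size, :].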

View File: vllm/model_executor/models/qwen2_moe.py

@@ -31,11 +31,10 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
@@ -93,28 +92,23 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.num_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
f"the number of experts {config.num_experts}.")
self.experts = nn.ModuleList([
Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config)
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
config.num_experts,
bias=False,
quant_config=None)
if config.shared_expert_intermediate_size > 0:
@@ -131,25 +125,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
1,
bias=False)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
@@ -162,18 +137,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
@@ -284,7 +254,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
)
if (layer_idx not in config.mlp_only_layers) and (
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
@@ -427,21 +402,36 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(self.config.num_experts) for shard_id,
weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# The checkpoint contains mlp.experts[0].gate_proj.
# Since the experts are handled below via expert_params_mapping,
# we must skip here BEFORE updating the name; otherwise the name
# would become mlp.experts[0].gate_up_proj, get rewritten again on
# a later mapping pass to mlp.experts[0].gate_gate_up_proj, and
# the weight would never load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name)
and name not in params_dict):
continue
if name not in params_dict:
continue
@@ -450,17 +440,27 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_expert." in name)
and name not in params_dict):
continue
if name not in params_dict:
continue
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
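The early "mlp.experts" skip above guards against a rename hazard: `name` is mutated across iterations of the stacked_params_mapping loop, so an expert weight that slips past the skip would be rewritten twice. An illustration of the failure mode described in the comment (string handling only, not part of the commit):

name = "model.layers.1.mlp.experts.0.gate_proj.weight"

# Entry ("gate_up_proj", "gate_proj", 0) rewrites the name first:
name = name.replace("gate_proj", "gate_up_proj")
# -> "model.layers.1.mlp.experts.0.gate_up_proj.weight"

# Without the skip, entry ("gate_up_proj", "up_proj", 1) then matches the
# already-rewritten name and mangles it:
name = name.replace("up_proj", "gate_up_proj")
# -> "model.layers.1.mlp.experts.0.gate_gate_up_proj.weight"  (no such param)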