diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml new file mode 100644 index 00000000..75a24e40 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x22B-Instruct-v0.1-FP8-Dynamic.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic -b "auto" -l 250 -f 5 -t 8 +model_name: "neuralmagic/Mixtral-8x22B-Instruct-v0.1-FP8-dynamic" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.86 + - name: "exact_match,flexible-extract" + value: 0.86 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml new file mode 100644 index 00000000..436ec219 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Mixtral-8x7B-Instruct-v0.1-FP8.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8 -b "auto" -l 250 -f 5 -t 4 +model_name: "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.624 + - name: "exact_match,flexible-extract" + value: 0.624 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml new file mode 100644 index 00000000..45d5efc8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 +model_name: "Qwen/Qwen2-57B-A14B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.792 + - name: "exact_match,flexible-extract" + value: 0.824 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 127ec5d9..2007dd2e 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -1,2 +1,3 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml +Qwen2-57B-A14-Instruct.yaml diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 22b6769a..2f9eee42 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype): for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 1dafae50..db837231 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + 
FusedMoEMethodBase)
 
 __all__ = [
     "fused_moe",
@@ -7,4 +9,6 @@ __all__ = [
     "fused_experts",
     "get_config_file_name",
     "grouped_topk",
+    "FusedMoE",
+    "FusedMoEMethodBase",
 ]
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
new file mode 100644
index 00000000..73cfcd7f
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -0,0 +1,197 @@
+from abc import abstractmethod
+from typing import Optional
+
+import torch
+
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
+
+logger = init_logger(__name__)
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
+    """MoE method without quantization."""
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True) -> torch.Tensor:
+
+        return fused_moe(x,
+                         layer.w13_weight,
+                         layer.w2_weight,
+                         router_logits,
+                         top_k,
+                         renormalize=renormalize,
+                         inplace=True)
+
+
+class FusedMoE(torch.nn.Module):
+    """FusedMoE layer for MoE models.
+
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
+    w13) and RowParallelLinear weights (down_proj / w2).
+
+    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
+    copy that naming convention here and handle any remapping in the
+    load_weights function in each model implementation.
+
+    Args:
+        num_experts: Number of experts in the model
+        top_k: Number of experts selected for each token
+        hidden_size: Input hidden state size of the transformer
+        intermediate_size: Intermediate size of the experts
+        params_dtype: Data type for the parameters.
+        reduce_results: Whether to all_reduce the output of the layer
+        renormalize: Whether to renormalize the logits in the fused_moe kernel
+        quant_config: Quantization config.
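+
+    Example (illustrative only; Mixtral-8x7B-like shapes, unquantized,
+    single GPU):
+
+        moe = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
+                       intermediate_size=14336, reduce_results=True)
+        # hidden_states: [num_tokens, hidden_size]
+        # router_logits: [num_tokens, num_experts]
+        out = moe(hidden_states, router_logits)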
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod()) + else: + self.quant_method = quant_config.get_quant_method(self) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: int, expert_id: int): + param_data = param.data + + # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. + # Follow up PR to enable fp8 for other MoE models. + if "input_scale" in weight_name or "w2.weight_scale" in weight_name: + if param_data[expert_id] != 1 and (param_data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}") + param_data[expert_id] = loaded_weight + # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. + # Follow up PR to enable fp8 for other MoE models. + elif "weight_scale" in weight_name: + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + assert "w1" in weight_name or "w3" in weight_name + shard_id = 0 if "w1" in weight_name else 1 + param_data[expert_id][shard_id] = loaded_weight + else: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, + 0:shard_size, :] = loaded_weight[shard, :] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size:2 * + shard_size, :] = loaded_weight[shard, :] + # w2, down_proj case: Load into only shard of w2. + elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. 
+ final_hidden_states = self.quant_method.apply( + self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5d503a22..dc2ca35c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,6 +6,8 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + fused_moe) from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -71,7 +73,9 @@ class Fp8Config(QuantizationConfig): if isinstance(layer, LinearBase): return Fp8LinearMethod(self) - if isinstance(layer, Attention): + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None @@ -270,6 +274,187 @@ class Fp8LinearMethod(LinearMethodBase): return torch.narrow(output, 0, 0, x.shape[0]) +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + layer.process_after_load = True + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. 
+        # If loading an fp16 checkpoint, do not (we will quantize in
+        # process_weights_after_loading()).
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            set_weight_attrs(w13_scale, extra_weight_attrs)
+            set_weight_attrs(w2_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.quant_config.activation_scheme == "static":
+            if not self.quant_config.is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "Found static activation scheme for checkpoint that "
+                    "was not serialized fp8.")
+
+            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                      dtype=torch.float32),
+                                           requires_grad=False)
+            layer.register_parameter("a13_scale", a13_scale)
+            set_weight_attrs(a13_scale, extra_weight_attrs)
+
+            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            layer.register_parameter("a2_scale", a2_scale)
+            set_weight_attrs(a2_scale, extra_weight_attrs)
+        else:
+            layer.a13_scale = None
+            layer.a2_scale = None
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        if (not hasattr(layer, "process_after_load")
+                or not layer.process_after_load):
+            return
+
+        # If checkpoint is fp16, quantize in place.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(layer.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(layer.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
+
+            # Re-initialize w13_scale because we directly quantize
+            # merged w13 weights and generate a single scaling factor.
+            layer.w13_scale = torch.nn.Parameter(torch.ones(
+                layer.num_experts,
+                dtype=torch.float32,
+                device=w13_weight.device),
+                                                 requires_grad=False)
+            for expert in range(layer.num_experts):
+                w13_weight[expert, :, :], layer.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], layer.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        layer.w2_weight.data[expert, :, :])
+            layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                  requires_grad=False)
+            layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                 requires_grad=False)
+            return
+
+        # If the checkpoint is fp8, account for the fact that the MoE
+        # kernels require a single activation scale and a single weight
+        # scale for w13 per expert.
+        else:
+            # Fp8 moe kernels require a single activation scale.
+            # We take the max of all the scales in case they differ.
+            if self.quant_config.activation_scheme == "static":
+                if layer.a13_scale is None or layer.a2_scale is None:
+                    raise ValueError(
+                        "QuantConfig has static quantization, but found "
+                        "activation scales are None.")
+                if (not all_close_1d(layer.a13_scale)
+                        or not all_close_1d(layer.a2_scale)):
+                    print_warning_once(
+                        "Found input_scales that are not equal for "
+                        "fp8 MoE layer. Using the maximum across experts "
+                        "for each layer.")
+                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
+                                                     requires_grad=False)
+                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
+                                                    requires_grad=False)
+
+            # Fp8 moe kernel needs a single weight scale for w13 per expert.
+            # We take the max, then dequantize and requantize each expert.
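+            # Requantization sketch (per expert): for each half s of w13,
+            #   w_fp16 ~= w13_weight[s] * w13_scale[s]
+            #   w13_weight[s] <- quantize(w_fp16, max(w13_scale))
+            # so both halves end up sharing the max scale.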
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :] = per_tensor_quantize( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True) -> torch.Tensor: + + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale) + + class Fp8KVCacheMethod(QuantizeMethodBase): """Supports loading kv-cache scaling factors from FP8 checkpoints. """ @@ -321,3 +506,8 @@ def per_tensor_dequantize( fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 05c36b9c..5144e7ea 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -27,13 +27,10 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -41,16 +38,12 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - per_tensor_dequantize, - per_tensor_quantize) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once @@ -66,227 +59,40 @@ class MixtralMoE(nn.Module): across ranks. 
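+
+    The expert weights themselves are owned by a FusedMoE sub-layer
+    (self.experts); this module only adds the routing gate on top.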
""" - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - self.quant_config = quant_config - - # FIXME(pcmoritz): Make this more general to support different - # quantization schemes - self.use_fp8 = isinstance(quant_config, Fp8Config) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, + self.gate = ReplicatedLinear(hidden_size, + num_experts, bias=False, - params_dtype=self.params_dtype, + params_dtype=params_dtype, quant_config=None) - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype), - requires_grad=False) - self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype), - requires_grad=False) - - set_weight_attrs(self.w13_weight, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_weight, { - "weight_loader": self.weight_loader, - }) - - # Used for fp8. - self.w13_scale = None - self.w2_scale = None - self.a13_scale = None - self.a2_scale = None - - if self.use_fp8: - # WEIGHT_SCALE (for fp8) - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, - 2, - dtype=torch.float32), - requires_grad=False) - self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) - - # If loading fp8 checkpoint, pass the weight loaders. 
- # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(self.w13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2_scale, { - "weight_loader": self.weight_loader, - }) - - # INPUT_SCALE (for fp8) - if quant_config.activation_scheme == "static": - if not quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - self.a13_scale = nn.Parameter(torch.ones( - self.num_total_experts, dtype=torch.float32), - requires_grad=False) - self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) - - set_weight_attrs(self.a13_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.a2_scale, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - - # Loading scales - if "input_scale" in weight_name or "w2.weight_scale" in weight_name: - if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param_data[expert_id]} " - f"vs. {loaded_weight}") - param_data[expert_id] = loaded_weight - elif "weight_scale" in weight_name: - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - assert "w1" in weight_name or "w3" in weight_name - shard_id = 0 if "w1" in weight_name else 1 - param_data[expert_id][shard_id] = loaded_weight - - def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. - if not self.use_fp8: - return - - # If checkpoint is fp16, quantize here. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(self.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(self.w2_weight.data, - dtype=torch.float8_e4m3fn) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, - dtype=torch.float32), - requires_grad=False) - for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[ - expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], self.w2_scale[ - expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :]) - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - - else: - # If checkpoint is fp8 + static, cleanup input_scales. - # Since state_dict has an input_scale per expert but our kernels - # are passed one input_scale shared across all experts. 
- if self.quant_config.activation_scheme == "static": - if self.a13_scale is None or self.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None.") - - if (not all_close_1d(self.a13_scale) - or not all_close_1d(self.a2_scale)): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - - self.a13_scale = nn.Parameter(self.a13_scale.max(), - requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), - requires_grad=False) - - assert self.w13_scale is not None - shard_size = self.intermediate_size - max_w13_scales = self.w13_scale.max(dim=1).values - for expert_id in range(self.num_total_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - self.w13_weight[expert_id][start:start + - shard_size, :], - self.w13_scale[expert_id][shard_id]) - self.w13_weight[expert_id][ - start:start + shard_size, :] = per_tensor_quantize( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False) + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=True, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - + final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(num_tokens, hidden_size) @@ -566,25 +372,28 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): expert_params_mapping = [ # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_scale" + if weight_name in ["w1", "w3"] else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, + shard_id) for expert_id in range(self.config.num_local_experts) + for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id) - ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ("experts.w13_weight" + if weight_name in ["w1", "w3"] else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] + for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) ] + [ # These are the activation scales for the experts # (param_name, weight_name, expert_id) - ("a13_scale" if weight_name 
in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] + ("experts.a13_scale" + if weight_name in ["w1", "w3"] else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id, + shard_id) for expert_id in range(self.config.num_local_experts) + for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) ] params_dict = dict(self.named_parameters()) @@ -604,7 +413,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -613,6 +423,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): weight_loader(param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id) break else: @@ -637,8 +448,3 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b3e7dfef..8decb446 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -31,11 +31,10 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, +from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -93,28 +92,23 @@ class Qwen2MoeSparseMoeBlock(nn.Module): quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.n_routed_experts = config.num_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.n_routed_experts: + + if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") + f"the number of experts {config.num_experts}.") - self.experts = nn.ModuleList([ - Qwen2MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) - self.pack_params() + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config) self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, + config.num_experts, bias=False, 
quant_config=None) if config.shared_expert_intermediate_size > 0: @@ -131,25 +125,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module): 1, bias=False) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -162,18 +137,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) - + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -284,7 +254,12 @@ class Qwen2MoeDecoderLayer(nn.Module): cache_config=cache_config, quant_config=quant_config, ) - if (layer_idx not in config.mlp_only_layers) and ( + + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, @@ -427,21 +402,36 @@ class Qwen2MoeForCausalLM(nn.Module): ("gate_up_proj", "up_proj", 1), ] + expert_params_mapping = [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] + else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) + for expert_id in range(self.config.num_experts) for shard_id, + weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_expert." in name) - and name not in params_dict): - continue if name not in params_dict: continue @@ -450,17 +440,27 @@ class Qwen2MoeForCausalLM(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_expert." in name) - and name not in params_dict): - continue - if name not in params_dict: - continue + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)