diff --git a/examples/offline_inference_arctic.py b/examples/offline_inference_arctic.py new file mode 100644 index 00000000..1fec3c99 --- /dev/null +++ b/examples/offline_inference_arctic.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="snowflake/snowflake-arctic-instruct", + quantization="deepspeedfp", + tensor_parallel_size=8, + trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. + +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c8..2926c7d1 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_experts, fused_moe, fused_topk, get_config_file_name) __all__ = [ "fused_moe", + "fused_topk", + "fused_experts", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3cb04194..bb7938b3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int, return None -def fused_moe( +def fused_topk( hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. +): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + M, _ = hidden_states.shape - E, N, _ = w1.shape if is_hip(): # The MoE kernels are not yet supported on ROCm. @@ -393,6 +349,33 @@ def fused_moe( del token_expert_indicies # Not used. Will be used in the future. if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + M, _ = hidden_states.shape + E, N, _ = w1.shape if override_config: config = override_config @@ -477,3 +460,63 @@ def fused_moe( out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1c652e34..5798bc35 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -19,6 +21,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, + "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 00000000..31cdffbc --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c5cdc059..d5263b50 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -54,6 +54,7 @@ _MODELS = { "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 00000000..796cef7c --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,521 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.arctic import ArcticConfig + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.layers = nn.ModuleList([ + ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + attn_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module): + + def __init__(self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + ) + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping = [] + expert_params_mapping = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 00000000..7780bf5e --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,204 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret