[Model] Add support for DBRX (#3660)
commit e24336b5a7 (parent d18f4e73f3)
README.md
@@ -67,6 +67,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct`, etc.)
 - DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
docs/source/models/supported_models.rst
@@ -27,6 +27,10 @@ Alongside each architecture, we include some popular models that use it.
     - ChatGLM
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
     - ✅︎
+  * - :code:`DbrxForCausalLM`
+    - DBRX
+    - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc.
+    -
   * - :code:`DeciLMForCausalLM`
     - DeciLM
     - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
requirements.txt
@@ -14,3 +14,4 @@ prometheus_client >= 0.18.0
 pynvml == 11.5.0
 triton >= 2.1.0
 outlines == 0.0.34
+tiktoken == 0.6.0  # Required for DBRX tokenizer
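The new `tiktoken` pin exists because the DBRX tokenizer is tiktoken-based and ships as remote code in the model repository. A minimal smoke test of that dependency, assuming network access and that `trust_remote_code=True` is acceptable (not part of the commit):

    from transformers import AutoTokenizer

    # The DBRX tokenizer implementation lives in the model repo, hence
    # trust_remote_code=True; it wraps a tiktoken encoding internally.
    tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct",
                                              trust_remote_code=True)
    print(tokenizer.encode("Hello, DBRX!"))  # token ids from the tiktoken vocab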
vllm/config.py
@@ -277,6 +277,11 @@ class ModelConfig:
             # Currently, tensor parallelism is not supported in this case.
             return 1

+        # For DBRX and MPT
+        if self.hf_config.model_type in ["dbrx", "mpt"]:
+            return getattr(self.hf_config.attn_config, "kv_n_heads",
+                           self.hf_config.num_attention_heads)
+
         attributes = [
             # For Falcon:
             "n_head_kv",
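DBRX (like MPT) nests its KV-head count inside `attn_config` instead of exposing a top-level attribute, so the lookup above falls back to `num_attention_heads` (plain multi-head attention) when `kv_n_heads` is absent. A standalone sketch of the same fallback, with illustrative numbers (not part of the commit):

    from types import SimpleNamespace

    # Stand-in for a DBRX-style HF config; 48/8 are illustrative GQA numbers.
    hf_config = SimpleNamespace(
        model_type="dbrx",
        num_attention_heads=48,
        attn_config=SimpleNamespace(kv_n_heads=8),
    )
    kv_heads = getattr(hf_config.attn_config, "kv_n_heads",
                       hf_config.num_attention_heads)
    print(kv_heads)  # 8; would fall back to 48 if kv_n_heads were missing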
vllm/model_executor/models/__init__.py
@@ -17,6 +17,7 @@ _MODELS = {
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
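Registering `DbrxForCausalLM` makes the architecture reachable through vLLM's normal entry points. A minimal offline-inference sketch (not part of the commit; `tensor_parallel_size` is illustrative and the checkpoint requires multiple large GPUs):

    from vllm import LLM, SamplingParams

    # trust_remote_code is for the tiktoken-based tokenizer in the model repo.
    llm = LLM(model="databricks/dbrx-instruct",
              tensor_parallel_size=8,
              trust_remote_code=True)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)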
vllm/model_executor/models/dbrx.py (new file, 421 lines)
@@ -0,0 +1,421 @@
# coding=utf-8
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig

class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            linear_method=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits

class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and the outputs are reduced across ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses GLU for each expert.
        # GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:,
                       shard_size:2 * shard_size, :] = loaded_weight[:,
                                                                     shard, :]
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)

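`fused_moe` hides the per-expert arithmetic, so as a mental model: the router logits go through a softmax, the top-k experts per token are kept and their weights renormalized, and each selected expert applies a SiLU GLU whose gate/up projections are the two halves of `ws` and whose down projection is `w2s`. An unfused, single-rank reference of that computation (illustrative only, assuming `tp_size == 1`; not part of the commit):

    import torch
    import torch.nn.functional as F

    def moe_reference(x, ws, w2s, router_logits, top_k):
        # x: [T, D]; ws: [E, 2*I, D]; w2s: [E, D, I]  (layouts as above)
        inter = ws.shape[1] // 2
        weights, experts = torch.topk(F.softmax(router_logits, dim=-1), top_k)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize=True
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):
            for w, e in zip(weights[t], experts[t]):
                gate = x[t] @ ws[e, :inter, :].t()   # w1 half of ws
                up = x[t] @ ws[e, inter:, :].t()     # v1 half of ws
                out[t] += w * (F.silu(gate) * up) @ w2s[e].t()
        return out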
class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states

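The head bookkeeping above is the usual GQA-under-TP recipe: query heads are always split across ranks, while KV heads are split when there are at least as many heads as ranks and replicated otherwise. Worked numbers under assumed, DBRX-like settings (48 query heads, 8 KV heads, head_dim 128, 8-way TP; not part of the commit):

    # Illustrative only: mirrors the arithmetic in DbrxAttention.__init__.
    total_num_heads, total_num_kv_heads, head_dim, tp = 48, 8, 128, 8

    num_heads = total_num_heads // tp                 # 6 query heads per rank
    num_kv_heads = max(1, total_num_kv_heads // tp)   # 1 KV head per rank
    q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
    # Each rank's Wqkv emits q_size + 2 * kv_size = 768 + 256 = 1024 features,
    # which forward() splits into q, k, v along the last dimension.
    print(q_size, kv_size)  # 768 128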
class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, linear_method)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual

class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
        self.ffn = DbrxExperts(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states

class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states

class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        expert_params_mapping = [(
            "ws" if weight_name in ["w1", "v1"] else "w2s",
            f"experts.mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
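`load_weights` relies on `expert_params_mapping` to fold the checkpoint's per-expert GLU tensors (`experts.mlp.w1`, `experts.mlp.v1`, `experts.mlp.w2`) into the two stacked parameters `ws` and `w2s`, with `DbrxExperts.weight_loader` writing each rank's shard. A small rename sketch showing the mapping in isolation (the checkpoint name is illustrative; not part of the commit):

    expert_params_mapping = [(
        "ws" if weight_name in ["w1", "v1"] else "w2s",
        f"experts.mlp.{weight_name}",
    ) for weight_name in ["w1", "v1", "w2"]]

    name = "transformer.blocks.0.ffn.experts.mlp.w1"  # illustrative
    for param_name, weight_name in expert_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name))
            # -> transformer.blocks.0.ffn.ws
            break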
vllm/transformers_utils/config.py
@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs import *

 _CONFIG_REGISTRY = {
     "chatglm": ChatGLMConfig,
+    "dbrx": DbrxConfig,
     "mpt": MPTConfig,
     "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
vllm/transformers_utils/configs/__init__.py
@@ -1,4 +1,5 @@
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
+from vllm.transformers_utils.configs.dbrx import DbrxConfig
 # RWConfig is for the original tiiuae/falcon-40b(-instruct) and
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
@@ -8,6 +9,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig

 __all__ = [
     "ChatGLMConfig",
+    "DbrxConfig",
     "MPTConfig",
     "RWConfig",
     "JAISConfig",
vllm/transformers_utils/configs/dbrx.py (new file, 277 lines)
@@ -0,0 +1,277 @@
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""

from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for the [`DbrxAttention`] class. It is used to
    instantiate attention layers according to the specified arguments,
    defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for the [`DbrxFFN`] class. It is used to
    instantiate feedforward layers according to the specified arguments,
    defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

class DbrxConfig(PretrainedConfig):
    """Configuration class for a [`DbrxModel`]. It is used to instantiate
    a Dbrx model according to the specified arguments, defining the model
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.


    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models."
            )

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
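Because `attn_config` and `ffn_config` accept plain dicts (exactly as they appear in a checkpoint's `config.json`), constructing a small config mirrors what `from_pretrained` does with the nested sections. A minimal sketch with illustrative values (not part of the commit):

    # Nested dicts are promoted to DbrxAttentionConfig / DbrxFFNConfig in
    # DbrxConfig.__init__; all values here are illustrative.
    config = DbrxConfig(
        d_model=1024,
        n_heads=8,
        n_layers=2,
        attn_config={"kv_n_heads": 2, "clip_qkv": 8.0},
        ffn_config={"moe_num_experts": 8, "moe_top_k": 2},
    )
    print(config.attn_config.kv_n_heads, config.ffn_config.moe_top_k)  # 2 2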