[Model] support minicpm3 (#8297)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by ywfang on 2024-09-14 22:50:26 +08:00; committed by GitHub
parent 1ef0d2efd0
commit 8a0cf1ddc3
7 changed files with 282 additions and 38 deletions


@@ -22,7 +22,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 
 # Run basic model test
 docker exec cpu-test bash -c "
-  pip install pytest matplotlib einops transformers_stream_generator
+  pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \


@@ -107,6 +107,10 @@ Decoder-only Language Models
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
     -
+  * - :code:`MiniCPM3ForCausalLM`
+    - MiniCPM3
+    - :code:`openbmb/MiniCPM3-4B`, etc.
+    -
   * - :code:`MistralForCausalLM`
     - Mistral, Mistral-Instruct
     - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
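With this table entry in place, MiniCPM3 loads through the regular vLLM entry points. A minimal offline-inference sketch (the model id comes from the row above; trust_remote_code is an assumption, since the MiniCPM3 config ships as remote code):

from vllm import LLM, SamplingParams

# Hedged example: model id taken from the supported-models table above;
# trust_remote_code is assumed to be needed for the remote MiniCPM3 config.
llm = LLM(model="openbmb/MiniCPM3-4B", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["Summarize MiniCPM3 in one sentence."], params)
print(outputs[0].outputs[0].text)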


@@ -21,6 +21,7 @@ compressed-tensors==0.4.0 # required for compressed-tensors
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
+datamodel_code_generator # required for minicpm3 test
 
 # TODO: Add this after fully implementing llava(mantis)
 # git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test


@@ -5,7 +5,8 @@ This tests bigger models and use half precision.
 Run `pytest tests/models/test_big_models.py`.
 """
 import pytest
-import torch
+
+from vllm.platforms import current_platform
 
 from ...utils import check_outputs_equal
 
@@ -19,10 +20,12 @@ MODELS = [
     # "Qwen/Qwen1.5-0.5B"  # Broken,
 ]
 
+if not current_platform.is_cpu():
+    # MiniCPM requires fused_moe which is not supported by CPU
+    MODELS.append("openbmb/MiniCPM3-4B")
+
 #TODO: remove this after CPU float16 support ready
-target_dtype = "float"
-if torch.cuda.is_available():
-    target_dtype = "half"
+target_dtype = "float" if current_platform.is_cpu() else "half"
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -39,7 +42,7 @@ def test_models(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(
@@ -57,7 +60,7 @@ def test_model_print(
     model: str,
     dtype: str,
 ) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
         print(vllm_model.model.llm_engine.model_executor.driver_worker.


@@ -43,6 +43,7 @@ _GENERATION_MODELS = {
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
+    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
     "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),


@@ -270,38 +270,47 @@ class MiniCPMDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.max_position_embeddings = getattr(config,
+                                               "max_position_embeddings", 8192)
+        self._init_attn_block()
+        self._init_ffn_block()
+
+    def _init_attn_block(self):
+        self.input_layernorm = RMSNorm(self.config.hidden_size,
+                                       eps=self.config.rms_norm_eps)
         self.self_attn = MiniCPMAttention(
             hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            num_heads=self.config.num_attention_heads,
+            num_kv_heads=self.config.num_key_value_heads,
+            rope_theta=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            max_position_embeddings=self.max_position_embeddings,
+            cache_config=self.cache_config,
+            quant_config=self.quant_config,
         )
+
+    def _init_ffn_block(self):
+        self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
+                                                eps=self.config.rms_norm_eps)
         self.num_experts = getattr(self.config, "num_experts", 0)
         if self.num_experts == 0:
             self.mlp = MiniCPMMLP(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
+                intermediate_size=self.config.intermediate_size,
+                hidden_act=self.config.hidden_act,
+                quant_config=self.quant_config,
             )
         else:
-            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
-                                  top_k=config.num_experts_per_tok,
-                                  hidden_size=config.hidden_size,
-                                  intermediate_size=config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+            self.mlp = MiniCPMMoE(
+                num_experts=self.config.num_experts,
+                top_k=self.config.num_experts_per_tok,
+                hidden_size=self.config.hidden_size,
+                intermediate_size=self.config.intermediate_size)
 
     def forward(
         self,
@@ -344,6 +353,8 @@ class MiniCPMModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -354,12 +365,16 @@
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self._init_layers()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+    def _init_layers(self):
+        self.layers = nn.ModuleList([
+            MiniCPMDecoderLayer(self.config, self.cache_config,
+                                self.quant_config)
+            for _ in range(self.config.num_hidden_layers)
+        ])
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         embedding = self.embed_tokens(input_ids)
         return embedding * self.config.scale_emb
@@ -431,13 +446,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
         self.config = config
         self.lora_config = lora_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
-        self.quant_config = quant_config
-        self.model = MiniCPMModel(config,
-                                  cache_config,
-                                  quant_config,
-                                  lora_config=lora_config)
+        self._init_model()
         unpadded_vocab_size = config.vocab_size
         if lora_config:
             unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -458,6 +471,12 @@
                                                 config.vocab_size)
         self.sampler = Sampler()
 
+    def _init_model(self):
+        self.model = MiniCPMModel(config=self.config,
+                                  cache_config=self.cache_config,
+                                  quant_config=self.quant_config,
+                                  lora_config=self.lora_config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
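The net effect of these hooks is that a variant model overrides only the _init_* method whose component differs. A hypothetical sketch of the pattern (not part of the commit; the real MiniCPM3 classes in the new file below follow the same shape):

from vllm.model_executor.models.minicpm import MiniCPMDecoderLayer


class MyVariantDecoderLayer(MiniCPMDecoderLayer):
    # Hypothetical subclass, for illustration only.

    def _init_attn_block(self):
        # Build a different attention module here; self.config,
        # self.cache_config, and self.quant_config are already set by
        # MiniCPMDecoderLayer.__init__ before this hook runs, and the FFN
        # block and forward() are inherited unchanged.
        super()._init_attn_block()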


@ -0,0 +1,216 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2024 The ModelBest team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
class MiniCPM3Attention(nn.Module):

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                         self.q_lora_rank,
                                         bias=False,
                                         quant_config=quant_config)
        self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
        self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                             self.num_heads * self.qk_head_dim,
                                             bias=False,
                                             quant_config=quant_config)

        self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
                                                   self.kv_lora_rank +
                                                   self.qk_rope_head_dim,
                                                   bias=False,
                                                   quant_config=quant_config)
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config)
        # O projection.
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.rotary_emb = get_rope(
            self.qk_rope_head_dim,
            rotary_dim=self.qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_local_heads,
                              self.qk_head_dim,
                              self.scaling,
                              num_kv_heads=self.num_local_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        q, _ = self.q_a_proj(hidden_states)
        q = self.q_a_layernorm(q)
        q, _ = self.q_b_proj(q)
        q = q.view(-1, self.num_local_heads, self.qk_head_dim)
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                          dim=-1)

        latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
        kv_a, _ = latent_cache.split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv, _ = self.kv_b_proj(kv_a)
        kv = kv.view(-1, self.num_local_heads,
                     self.qk_nope_head_dim + self.v_head_dim)

        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank:]

        q_pe, k_pe = self.rotary_emb(
            positions,
            q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
            k_pe.reshape(-1, self.qk_rope_head_dim))
        q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        q[..., self.qk_nope_head_dim:] = q_pe

        k = torch.empty_like(q)
        k[..., :self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim:] = k_pe

        q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
        k = k.view(-1, self.num_local_heads * self.qk_head_dim)
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim],
            value=0).view(-1, self.num_local_heads * self.qk_head_dim)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output = attn_output.view(
            -1, self.num_local_heads,
            self.qk_head_dim)[..., :self.v_head_dim].reshape(
                -1, self.num_local_heads * self.v_head_dim)

        output, _ = self.o_proj(attn_output)
        return output


class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):

    def _init_attn_block(self):
        self.input_layernorm = RMSNorm(self.config.hidden_size,
                                       eps=self.config.rms_norm_eps)
        self.self_attn = MiniCPM3Attention(
            config=self.config,
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            qk_nope_head_dim=self.config.qk_nope_head_dim,
            qk_rope_head_dim=self.config.qk_rope_head_dim,
            v_head_dim=self.config.v_head_dim,
            q_lora_rank=self.config.q_lora_rank,
            kv_lora_rank=self.config.kv_lora_rank,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
        )


class MiniCPM3Model(MiniCPMModel):

    def _init_layers(self):
        self.layers = nn.ModuleList([
            MiniCPM3DecoderLayer(self.config, self.cache_config,
                                 self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])


class MiniCPM3ForCausalLM(MiniCPMForCausalLM):

    def _init_model(self):
        self.model = MiniCPM3Model(config=self.config,
                                   cache_config=self.cache_config,
                                   quant_config=self.quant_config,
                                   lora_config=self.lora_config)
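MiniCPM3Attention reads its head dimensions and low-rank projection sizes directly from the HF config object passed into the decoder layer. A hedged way to inspect those fields for the published checkpoint (the attribute names are the ones consumed by the constructor above; the values depend on the model's config.json):

from transformers import AutoConfig

# trust_remote_code is assumed because the MiniCPM3 config ships as remote code.
cfg = AutoConfig.from_pretrained("openbmb/MiniCPM3-4B", trust_remote_code=True)
print(cfg.qk_nope_head_dim, cfg.qk_rope_head_dim, cfg.v_head_dim)
print(cfg.q_lora_rank, cfg.kv_lora_rank, cfg.num_attention_heads)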