[Model] Add PLaMo2 (#14323)

Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Signed-off-by: shemmi <shemmi@preferred.jp>
Co-authored-by: Kento Nozawa <nzw0301@preferred.jp>
Co-authored-by: Hiroaki Mikami <mhiroaki@preferred.jp>
Co-authored-by: Calvin Metzger <metzger@preferred.jp>
Shinichi Hemmi 2025-04-16 11:31:30 +09:00 committed by GitHub
parent fdcb850f14
commit 3badb0213b
9 changed files with 800 additions and 24 deletions


@@ -400,8 +400,9 @@ steps:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]
@@ -411,6 +412,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
- pytest -v -s models/embedding/language -m core_model
@@ -422,6 +425,8 @@ steps:
- tests/models/embedding/language
- tests/models/encoder_decoder/language
commands:
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
- pip install causal-conv1d
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'


@@ -497,6 +497,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
*
* ✅︎
- * `Plamo2ForCausalLM`
* PLaMo2
* `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
*
*
- * `QWenLMHeadModel`
* Qwen
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
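
For context, the registry changes further down register `Plamo2ForCausalLM` as a text-generation architecture that requires `trust_remote_code=True`, and the model is added as a V0-only model (the CI above runs its tests with VLLM_USE_V1=0). A minimal offline-inference sketch under those assumptions, with an illustrative prompt and sampling settings:

from vllm import LLM, SamplingParams

# PLaMo2 defaults to bfloat16 (see the dtype handling in this commit).
llm = LLM(model="pfnet/plamo-2-1b", trust_remote_code=True, dtype="bfloat16")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of Japan is"], params)
print(outputs[0].outputs[0].text)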


@@ -27,6 +27,7 @@ torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test


@@ -111,6 +111,7 @@ einops==0.8.0
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# vector-quantize-pytorch
# vocos
einx==0.3.0
@@ -233,6 +234,8 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
@@ -268,6 +271,8 @@ mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
# via rouge-score
num2words==0.5.14
@@ -360,6 +365,7 @@ packaging==24.1
# fastparquet
# huggingface-hub
# lazy-loader
# mamba-ssm
# matplotlib
# peft
# plotly
@@ -571,6 +577,7 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==75.8.0
# via
# mamba-ssm
# pytablewriter
# torch
shellingham==1.5.4
@@ -627,6 +634,7 @@ torch==2.6.0
# encodec
# fastsafetensors
# lm-eval
# mamba-ssm
# peft
# runai-model-streamer
# sentence-transformers
@@ -664,6 +672,7 @@ transformers==4.51.1
# -r requirements/test.in
# genai-perf
# lm-eval
# mamba-ssm
# peft
# sentence-transformers
# transformers-stream-generator


@@ -9,9 +9,15 @@ from vllm.sampling_params import SamplingParams
from ...utils import check_outputs_equal
# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
MODELS = [
"ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
"pfnet/plamo-2-1b"
]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running PLaMo2 with the transformers implementation requires installing
# the causal-conv1d package, which is not listed as a test dependency because
# it is not compatible with pip-compile.
@pytest.mark.parametrize("model", MODELS)
@@ -25,21 +31,11 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
# numeric error produces different generation
if "Bamba" in model:
example_prompts.pop(3)
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -94,6 +90,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
# correctly for n > 1 decoding steps inside a
# chunked prefill forward pass (where we have both prefills
# and decoding together )
if 'plamo-2' in model:
dtype = "float" # use a different dtype for plamo
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
@@ -125,20 +125,14 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"
elif "plamo-2-1b" in model:
example_prompts.pop(7)
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,
@@ -208,7 +202,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
vllm_config = EngineArgs(model=model).create_engine_config()
vllm_config = EngineArgs(model=model,
trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(
len(example_prompts)):
example_prompts.append(example_prompts[0])


@@ -204,6 +204,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",


@@ -2838,6 +2838,13 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.info(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype = torch.bfloat16
from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
@@ -2867,6 +2874,11 @@ def _get_and_verify_dtype(
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
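
Taken together, the two PLaMo2 branches above mean that dtype="auto" resolves to bfloat16, while an explicitly requested float16 is honored but warned about. A standalone sketch of that rule (illustrative only, not vLLM's actual helper):

import torch

def resolve_plamo2_dtype(requested: str) -> torch.dtype:
    # "auto": default to bfloat16, since float16 is numerically unstable.
    if requested == "auto":
        return torch.bfloat16
    # An explicit float16 request is allowed but discouraged.
    if requested == "float16":
        print("warning: float16 is unstable for PLaMo2; prefer bfloat16/float32")
        return torch.float16
    # Other explicit dtypes ("bfloat16", "float32", ...) pass through unchanged.
    return getattr(torch, requested)

assert resolve_plamo2_dtype("auto") == torch.bfloat16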


@@ -0,0 +1,746 @@
# SPDX-License-Identifier: Apache-2.0
"""Inference-only PLaMo2 model."""
import math
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
# Only used for type hinting.
class Plamo2Config(PretrainedConfig): # type: ignore
model_type: str = "plamo2"
hidden_size: int
num_hidden_layers: int
rms_norm_eps: float
# Attention
num_attention_heads: int
hidden_size_per_head: int
num_key_value_heads: int
# Mamba
mamba_d_state: int
mamba_d_conv: int
mamba_num_heads: int
mamba_step: int
# MLP
intermediate_size: int
# Tokenizer
vocab_size: int
class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
def _init_weights(self, module: torch.nn.Module) -> None:
std = 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
dt_min = 0.001
dt_max = 0.1
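# dt is sampled log-uniformly in [dt_min, dt_max]; the value returned below is
# softplus^-1(dt), so applying softplus (delta_softplus=True in the scan
# kernels) recovers a time step in that range.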
dt = torch.exp(
torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) +
math.log(dt_min))
dt = torch.clamp(dt, 1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))
return inv_dt
def is_mamba(config: Plamo2Config, i: int) -> bool:
assert config.mamba_step > 1
if config.num_hidden_layers <= (config.mamba_step // 2):
# use attention in last layer
return i != config.num_hidden_layers - 1
return (i % config.mamba_step) != (config.mamba_step // 2)
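# Illustrative example: with mamba_step=2 and num_hidden_layers=4, layers 0 and
# 2 are Mamba mixers while layers 1 and 3 use attention, i.e. attention appears
# once every mamba_step layers at offset mamba_step // 2.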
# TODO(Shinichi): Replace this with RMSNorm.
def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
input_shape = hidden_states.shape
hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape)
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = hidden_states.to(input_dtype)
hidden_states = weight * hidden_states
return hidden_states.reshape(input_shape)
def _swiglu(h: torch.Tensor) -> torch.Tensor:
h0, h1 = h.chunk(2, dim=-1)
return torch.nn.functional.silu(h0) * h1
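# e.g. for h of shape (..., 2 * d), _swiglu returns silu(h[..., :d]) * h[..., d:]
# of shape (..., d), halving the gate_up_proj output back to intermediate_size.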
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class Plamo2MambaMixer(nn.Module):
# TODO(Shinichi): Rebase on Mamba2 implementation.
def __init__(self,
config: Plamo2Config,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
max_model_len: int,
prefix: str = "",
**kwargs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.intermediate_size = (config.mamba_num_heads *
config.hidden_size_per_head)
self.hidden_size_per_head = config.hidden_size_per_head
self.num_heads = config.mamba_num_heads
self.time_step_rank = max(64, self.hidden_size // 16)
self.use_conv_bias = False
self.use_bias = False
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=self.use_conv_bias,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it.
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.intermediate_size] * 2,
bias=self.use_bias,
prefix=f"{prefix}.in_proj",
)
# selective projection used to make dt, B and C input dependent
self.bcdt_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False,
prefix=f"{prefix}.bcdt_proj",
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.time_step_rank,
self.num_heads,
bias=False,
prefix=f"{prefix}.dt_proj",
)
self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias,
input_is_parallel=True,
prefix=f"{prefix}.out_proj",
)
# The activation function is fixed to SiLU.
self.activation = "silu"
self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams,
**kwargs,
) -> torch.Tensor:
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0]
# Reshaping the projected states as in modeling_plamo.py.
length = len(hidden_states)
projected_states = projected_states.reshape(length, self.num_heads, -1)
gate, hidden_states = torch.split(
projected_states,
[self.hidden_size_per_head, self.hidden_size_per_head],
dim=-1)
hidden_states = hidden_states.reshape(length, -1).transpose(0, 1)
gate = gate.reshape(length, -1).transpose(0, 1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
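# query_start_loc / context_lens_tensor are populated for prefill (or mixed
# prefill + decode) batches, so the branch below runs the full-sequence conv and
# selective-scan kernels; otherwise every request contributes a single decode
# token and the per-step *_update kernels read/write the cached states directly.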
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0]
# Splitting the ssm_parameters as in modeling_plamo.py.
B, C, time_step = torch.split(
ssm_parameters,
[self.ssm_state_size, self.ssm_state_size, self.time_step_rank],
dim=-1,
)
time_step = self.dt_norm(time_step.contiguous())
B = self.B_norm(B.contiguous())
C = self.C_norm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_bias.float() if hasattr(
self.dt_proj, "bias") else None)
# Broadcasting as in modeling_plamo.py.
discrete_time_step = discrete_time_step.transpose(
0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head)
discrete_time_step = discrete_time_step.reshape(
-1, self.intermediate_size).transpose(0, 1)
time_proj_bias = time_proj_bias[...,
None].expand(-1,
self.hidden_size_per_head)
time_proj_bias = time_proj_bias.reshape(self.intermediate_size)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
class DenseMLP(nn.Module):
def __init__(
self,
config: Plamo2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
self.hidden_size, [self.intermediate_size] * 2,
bias=False,
prefix=f"{prefix}.gate_up_proj",
quant_config=quant_config)
self.down_proj = RowParallelLinear(self.intermediate_size,
self.hidden_size,
bias=False,
prefix=f"{prefix}.down_proj",
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
h = self.gate_up_proj(hidden_states)[0]
h = _swiglu(h)
output, _ = self.down_proj(h)
return output # type: ignore
class Plamo2AttentionMixer(nn.Module):
def __init__(self,
config: Plamo2Config,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
max_model_len: int | None = None,
prefix: str = "",
**kwargs) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size_per_head
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.rope_theta = config.rope_theta if hasattr(config,
"rope_theta") else 10000
self.rope_scaling = config.rope_scaling if hasattr(
config, "rope_scaling") else None
assert max_model_len is not None, "max_model_len must be provided"
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_model_len,
base=self.rope_theta,
rope_scaling=self.rope_scaling,
)
self.q_weight = torch.nn.Parameter(
torch.ones((self.num_heads, config.hidden_size_per_head)))
self.k_weight = torch.nn.Parameter(
torch.ones((self.num_kv_heads, config.hidden_size_per_head)))
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
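# QK normalization: _rms_norm reshapes q and k so their trailing dimensions
# match q_weight (num_heads, head_dim) and k_weight (num_kv_heads, head_dim),
# i.e. a per-head RMSNorm applied before RoPE.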
q = _rms_norm(q, self.q_weight, 1e-6)
k = _rms_norm(k, self.k_weight, 1e-6)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Plamo2DecoderLayer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
layer_idx: int,
max_model_len: int | None = None,
prefix: str = "",
**kwargs) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
max_model_len = vllm_config.scheduler_config.max_model_len
self.is_mamba = is_mamba(config, layer_idx)
if self.is_mamba:
self.mixer = Plamo2MambaMixer(config=config,
cache_config=cache_config,
quant_config=quant_config,
max_model_len=max_model_len,
prefix=f"{prefix}.mixer")
else:
self.mixer = Plamo2AttentionMixer(config=config,
cache_config=cache_config,
quant_config=quant_config,
max_model_len=max_model_len,
prefix=f"{prefix}.mixer")
self.mlp = DenseMLP(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.pre_mixer_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_mixer_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_mlp_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_mlp_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.pre_mixer_norm(hidden_states)
else:
hidden_states, residual = self.pre_mixer_norm(
hidden_states, residual)
hidden_states = self.mixer(positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params)
hidden_states = self.post_mixer_norm(hidden_states)
# Fully Connected
hidden_states, residual = self.pre_mlp_norm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_mlp_norm(hidden_states)
return hidden_states, residual
class Plamo2Decoder(torch.nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
self.layers = nn.ModuleList([
Plamo2DecoderLayer(vllm_config=vllm_config,
layer_idx=i,
prefix=f"{prefix}.layers.{i}")
for i in range(num_hidden_layers)
])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
) -> torch.Tensor:
mamba_cache_index = 0
for layer in self.layers:
layer_mamba_cache_params = None
if layer.is_mamba:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
mamba_cache_index)
mamba_cache_index += 1
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
return hidden_states, residual
class Plamo2Model(Plamo2PreTrainedModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config.model_config.hf_config)
config = vllm_config.model_config.hf_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
prefix=f"{prefix}.embed_tokens",
)
self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO(Shinichi): Implement pipeline parallelism.
hidden_states = self.embed_tokens(input_ids)
residual = None
hidden_states, residual = self.layers(
positions=positions,
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
scheduler_config = vllm_config.scheduler_config
assert not vllm_config.cache_config.enable_prefix_caching, \
"PLaMo2 currently does not support prefix caching"
super().__init__(config)
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
# ModelConfig.get_head_size assumes head_dim is set or calculated as
# hidden_size // num_attention_heads. However, this is not always
# the case for PLaMo2, as indicated by the FIXME comment.
self.config.head_dim = self.config.hidden_size_per_head
self.model = Plamo2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = self.config.vocab_size
self.unpadded_vocab_size = self.config.vocab_size
num_embeddings = ((self.vocab_size + 15) // 16) * 16
self.lm_head = ParallelLMHead(
num_embeddings,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=f"{prefix}.lm_head",
)
if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
# Used to track and store the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size)
self.sampler = get_sampler()
# Initialize weights and apply final processing
self.post_init()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
conv_state_shape = (
hidden_size // world_size,
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
hidden_size // world_size,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
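# Illustrative sizes (hypothetical config values, not a published checkpoint):
# with mamba_num_heads=32, hidden_size_per_head=128, mamba_d_conv=4,
# mamba_d_state=64 and a single tensor-parallel rank, each Mamba layer keeps a
# conv state of shape (4096, 3) and an SSM state of shape (4096, 64) per sequence.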
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# When tie_word_embeddings=True, lm_head.weight is absent from params_dict,
# so loading it from the checkpoint would raise a KeyError; skip it instead.
if name == "lm_head.weight" and self.config.tie_word_embeddings:
assert "lm_head.weight" not in params_dict
continue
# Update the weight names to be compatible with the vllm version
# of the model.
# Do not change the order of the replacements.
replacements = {
# Rename incompatible weight names.
".A_log": ".A",
".B_norm_weight": ".B_norm.weight",
".C_norm_weight": ".C_norm.weight",
".dt_norm_weight": ".dt_norm.weight",
}
# Apply replacements based on the defined mappings
for old, new in replacements.items():
if old in name:
name = name.replace(old, new)
# Broadcast the loaded weight to match the model's parameter shape.
if ".A" in name:
loaded_weight = loaded_weight[:, None, None].expand(
-1, self.config.hidden_size_per_head,
self.config.mamba_d_state)
loaded_weight = loaded_weight.reshape(
-1, self.config.mamba_d_state)
elif ".D" in name:
loaded_weight = loaded_weight[:, None].expand(
-1, self.config.hidden_size_per_head)
loaded_weight = loaded_weight.reshape(-1)
# vllm's RMSNorm does not support an offset parameter yet, so each norm's
# offset is folded into its loaded weight here.
if ".pre_mixer_norm" in name:
loaded_weight += 1.0
elif ".post_mixer_norm" in name:
loaded_weight += 1.0 / 5
elif ".pre_mlp_norm" in name:
loaded_weight += 1.0
elif ".post_mlp_norm" in name:
loaded_weight += 1.0 / (5**1.5)
elif "model.norm.weight" in name:
loaded_weight += 1.0
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)


@@ -99,6 +99,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),