# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)
import torch
import torch.types
from PIL import Image
from torch import nn
from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
}
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVRawImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds."""
image: RawImageType
# Image bound token ids as 0-dim scalar tensors.
im_start_id: torch.Tensor
im_end_id: torch.Tensor
slice_start_id: NotRequired[torch.Tensor]
slice_end_id: NotRequired[torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: List[torch.Tensor]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Note that the image size may vary, so we pass it as a list
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
tgt_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(start, stop)` format.
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
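# Illustrative sketch (comments only, not executed): for a request with one
# sliced image, a pixel-value input would look roughly like
#   {"type": "pixel_values",
#    "data": [tensor(3, H_i, W_i), ...],            # one tensor per slice
#    "image_bounds": tensor([[start, stop], ...]),  # placeholder spans
#    "tgt_sizes": tensor([[h_patches, w_patches], ...])}
# The exact number of slices and the patch grid come from the HF image
# processor and vary with the input resolution.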
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class Resampler2_5(BaseResampler):
def __init__(self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(num_queries,
embed_dim,
num_heads,
kv_dim,
norm_layer,
quant_config=quant_config,
prefix=prefix)
self.max_size = max_size
self._set_2d_pos_cache(self.max_size)
self.apply(self._init_weights)
def _set_2d_pos_cache(self,
max_size: Tuple[int, int],
device: torch.types.Device = "cpu") -> None:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
max_size,
version=(2, 5))
pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
self.register_buffer("pos_embed", pos_embed, persistent=False)
def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
device: torch.types.Device) -> None:
max_h = tgt_sizes[:, 0].max().item()
max_w = tgt_sizes[:, 1].max().item()
assert isinstance(max_h, int) and isinstance(max_w, int)
if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = (
max(max_h, self.max_size[0]),
max(max_w, self.max_size[1]),
)
self._set_2d_pos_cache(self.max_size, device)
def forward(self, x: torch.Tensor,
tgt_sizes: torch.Tensor) -> torch.Tensor:
assert x.shape[0] == tgt_sizes.shape[0]
bs = x.shape[0]
device = x.device
dtype = x.dtype
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
self._adjust_pos_cache(tgt_sizes, device=device)
max_patch_len = patch_len.max().item()
assert isinstance(max_patch_len, int)
key_padding_mask = torch.zeros((bs, max_patch_len),
dtype=torch.bool,
device=device)
pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i].tolist()
pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
(tgt_h * tgt_w, -1)).to(dtype)) # patches * D
key_padding_mask[i, patch_len[i]:] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
batch_first=True,
padding_value=0.0).permute(
1, 0,
2) # BLD => L * B * D
x, _ = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
out = self.attn(
self._repeat(q, bs), # Q * B * D
x + pos_embed, # L * B * D + L * B * D
x,
key_padding_mask=key_padding_mask,
)[0]
# out: Q * B * D
x = out.permute(1, 0, 2) # B * Q * D
x = self.ln_post(x)
x = x @ self.proj
return x
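# Shape walkthrough for Resampler2_5.forward (illustrative, assuming two
# images with tgt_sizes [[32, 32], [16, 24]]):
#   patch_len        = [1024, 384]
#   pos_embed        : padded to max(patch_len) and permuted to L * B * D
#   key_padding_mask : positions >= patch_len[i] are True and thus ignored
#   output           : B * num_queries * embed_dim after ln_post and proj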
def _build_image_input(ctx: InputContext,
image: RawImageType) -> MiniCPMVRawImageInput:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id))
else:
return MiniCPMVRawImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id))
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
version_float = getattr(config, "version", None)
# The old configs do not include a version number.
# TODO: Remove this after the HF repos are updated
if version_float is None:
if config.hidden_size == 2304 and config.query_num == 64:
return (2, 0)
return (2, 5)
version_str = str(version_float)
return tuple(int(x) for x in version_str.split("."))
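# Examples of the version inference above (illustrative):
#   config.version == 2.6                            -> (2, 6)
#   no `version`, hidden_size 2304 and query_num 64  -> (2, 0)
#   no `version`, any other sizes                    -> (2, 5)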
def get_max_minicpmv_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config()
return getattr(hf_config, "query_num", 64)
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_prompt_token_counts((0, seq_len))
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
num_images: int):
width = height = hf_config.image_size
image = _build_image_input(ctx,
image=Image.new("RGB", (width, height),
color=0))
return {"image": [image] if num_images == 1 else [image] * num_images}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)
return DummyData(seq_data, mm_data)
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_processor = cached_get_image_processor(model_config.tokenizer)
def get_placeholder(image_size: Tuple[int, int], num_image: int):
if version == (2, 0) or version == (2, 5):
return image_processor. \
get_slice_image_placeholder(image_size)
return image_processor. \
get_slice_image_placeholder(image_size, num_image)
prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
if prompt is None:
prompt = tokenizer.decode(token_ids)
pattern = "(<image>./</image>)"
images = multi_modal_data["image"]
image_tags = re.findall(pattern, prompt)
if len(image_tags) == 0:
new_token_ids = token_ids
new_prompt = prompt
else:
if isinstance(images, dict):
image_size_list = images.get("image_size_list")
images = [images.get("image_embeds")]
else:
if isinstance(images, Image.Image):
images = [images]
image_size_list = [image.size for image in images]
text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = []
for i in range(len(image_size_list)):
new_prompt_chunks += [
text_chunks[i],
get_placeholder(image_size_list[i], i)
]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)
new_token_ids = tokenizer.encode(new_prompt)
multi_modal_data["image"] = [
_build_image_input(ctx, image) for image in images
]
return token_inputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
)
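# Illustrative example of the placeholder expansion above (the exact
# placeholder text comes from the HF image processor and differs between
# versions): a prompt such as
#   "(<image>./</image>)\nWhat is in the image?"
# with one image is rewritten so that the image tag is replaced by a run of
# image/slice placeholder tokens sized to `query_num` positions per (sliced)
# image, and the expanded prompt is re-tokenized into `new_token_ids`.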
def input_mapper_for_minicpmv(ctx: InputContext, data: object):
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
if not isinstance(data, list):
raise ValueError(
f"Image input must be a list of MiniCPMVImageInput, got {data}")
if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
batch_data = {
"image_embeds": data[0]['image'],
}
else:
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data
if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"]
batch_data["im_end_id"] = data[0]["im_end_id"]
if "slice_start_id" in data[0]:
batch_data["slice_start_id"] = data[0]["slice_start_id"]
batch_data["slice_end_id"] = data[0]["slice_end_id"]
return MultiModalInputs(batch_data)
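# Note on the mapper above: if the caller already supplies image embeddings
# (a torch.Tensor under the "image" key), they are forwarded as
# "image_embeds"; otherwise the HF image processor produces pixel values and
# target sizes. The image-bound token ids are carried along in either case so
# the model can later locate the image positions in the prompt.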
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
"""
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings`, but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model
# and config classes.
self.config = config
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config,
cache_config,
quant_config,
prefix="llm")
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix="resampler")
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="llm.lm_head")
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
vlm_embedding *= self.config.scale_emb
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1,
vlm_embedding.shape[-1]),
vision_hidden_states.view(-1,
vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
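# Sketch of the scatter above (illustrative numbers): if image_bounds is
# [[5, 101]] and the resampler produced 96 vision vectors, token positions
# 5..100 of `vlm_embedding` are overwritten in place with those vectors,
# while every other position keeps its text-token embedding.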
def _get_image_bounds(
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])
image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
image_end_tokens, = torch.where(end_cond)
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
return torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
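# Illustrative example for _get_image_bounds (hypothetical token ids): with
# im_start_id == 101 and im_end_id == 102, an input such as
#   [..., 101, u, u, u, 102, ...]   (u = image placeholder tokens)
# yields a single row [i + 1, j], where i and j are the indices of 101 and
# 102, i.e. the half-open span of placeholder positions between the markers.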
def _parse_and_validate_inputs(
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", [])
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return MiniCPMVImageEmbeddingInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of target sizes. "
f"Got type: {type(tgt_sizes)}")
if len(pixel_values) != len(tgt_sizes):
raise ValueError("Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}")
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError("Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}")
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n
# NOTE: Input IDs do not contain image tokens during memory profiling,
# so we allow them to be empty.
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError("Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}")
if len(pixel_values_flat) == 0:
return None
if im_start_id is None:
return None
return MiniCPMVImagePixelInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
output = self.llm(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
)
return output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
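# Checkpoint shards such as `q_proj`/`k_proj`/`v_proj` are merged into the
# fused `qkv_proj` parameter (and `gate_proj`/`up_proj` into
# `gate_up_proj`) through each parameter's shard-aware `weight_loader`.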
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading = False
if self.is_default_weight_loading(name):
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if is_pp_missing_parameter(
name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm",
connector="resampler",
tower_model="vpm")
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
raise NotImplementedError
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
raise NotImplementedError
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
raise NotImplementedError
class MiniCPMV2_0(MiniCPMVBaseModel):
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 0)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
# TODO: refactor this vision model
try:
import timm
except ImportError as e:
raise ImportError("Please install timm==0.9.10") from e
with set_default_torch_dtype(torch.float16):
model = timm.create_model(
"vit_so400m_patch14_siglip_384.webli",
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True,
)
if (isinstance(model, timm.models.VisionTransformer)
and model.attn_pool is not None):
model.attn_pool = torch.nn.Identity()
if self.config.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.llm.embed_tokens(input_ids)
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2(embed_dim=embed_dim,
num_heads=embed_dim // 128,
grid_size=int(
math.sqrt(self.config.query_num)),
kv_dim=vision_dim,
adaptive=False,
do_post_projection=True,
quant_config=quant_config,
prefix=prefix)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res = []
dtype = self.vpm.pos_embed.data.dtype
for pixel_value in pixel_values:
H, W = pixel_value[0].shape[-2:]
tgt_size = (
math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]),
)
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if (hasattr(self.vpm, "num_prefix_tokens")
and self.vpm.num_prefix_tokens > 0):
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
return self.get_vision_embedding(pixel_values)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
# Currently, vLLM does not support BNB quantization for the `out_proj`
# of the resampler, so it is necessary to distinguish between the
# vision encoder's and the resampler's `out_proj`. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj.", # vision encoder out_proj
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 5)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config,
prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2_5(num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(pixel_values,
patch_attention_mask=patch_attn_mask)
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
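# Mark the valid (non-padded) patch positions for each image; padded
# positions stay False and are ignored by the vision encoder's attention.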
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
# BitsAndBytes-specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
".self_attn.out_proj.",
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 6)
def init_llm(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config,
prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler = Resampler2_5(num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6
}
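# Dispatch example (illustrative): a config with `version == 2.6` resolves to
# MiniCPMV2_6 below, while a legacy config without a `version` field falls
# back to MiniCPMV2_0 or MiniCPMV2_5 based on `hidden_size`/`query_num`,
# mirroring get_version_by_config.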
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []
def __new__(cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0)
else:
version = (2, 5)
else:
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(config, multimodal_config, cache_config,
quant_config)
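# A minimal offline-inference sketch for this model family (kept as comments
# so importing this module has no side effects; the model id, prompt format,
# and engine arguments below are illustrative and may need adjustment):
#
#     from PIL import Image
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)
#     image = Image.open("example.jpg")
#     outputs = llm.generate(
#         {
#             "prompt": "(<image>./</image>)\nWhat is in the image?",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)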