[Misc] Clean up MiniCPM-V/O code (#15337)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-25 18:22:52 +08:00 committed by GitHub
parent 3e2f37a69a
commit a9e879b316
7 changed files with 521 additions and 651 deletions

View File

@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

View File

@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
# "aria": VLMTestInfo(
# models=["rhymes-ai/Aria"],
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
# prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
# img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
# max_model_len=4096,
# max_num_seqs=2,
# auto_cls=AutoModelForImageTextToText,
# single_image_prompts=IMAGE_ASSETS.prompts({
# "stop_sign": "<vlm_image>Please describe the image shortly.",
# "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
# }),
# multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
# stop_str=["<|im_end|>"],
# image_size_factors=[(0.10, 0.15)],
# max_tokens=64,
# marks=[large_gpu_mark(min_gb=64)],
# ),
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<vlm_image>Please describe the image shortly.",
"cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
marks=[large_gpu_mark(min_gb=64)],
),
"blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE,
@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
num_video_frames=16,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
),
@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = {
),
"minicpmo_26": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
),
"minicpmo_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
@ -404,6 +417,18 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"minicpmv_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE),

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from functools import partial
from typing import Optional, Union
@ -29,7 +28,7 @@ def _test_processing_correctness(
hit_rate: float,
num_batches: int,
simplify_rate: float,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
def _test_processing_correctness_mistral(
@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
# yapf: disable
@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
@ -290,7 +294,7 @@ def test_processing_correctness(
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys = ['audio_features']
ignore_mm_keys = {"audio_features"}
_test_processing_correctness(
model_id,
@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
)
def _inputs_equal(
def _assert_inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
*,
ignore_mm_keys: Optional[set[str]] = None,
msg: str = "",
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
if ignore_mm_keys is None:
ignore_mm_keys = set()
if msg is None:
assert "mm_kwargs" in a and "mm_kwargs" in b
else:
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs'].
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
This is mainly to avoid doing exact match of audio_features in ultravox.
Args:
result: Result to drop keys from
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
"""
if not ignore_mm_keys:
return result
if 'mm_kwargs' in result:
result = copy.deepcopy(result)
mm_kwargs = result['mm_kwargs']
for key in ignore_mm_keys:
mm_kwargs.pop(key, None)
for items in mm_kwargs._items_by_modality.values():
for item in items:
for key in ignore_mm_keys:
item.pop(key, None)
return result
if msg is None:
assert a == b
else:
assert a == b, msg

View File

@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images

View File

@ -23,7 +23,7 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
TypedDict, Union)
import torch
@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config)
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: torch.Tensor
audio_features: torch.Tensor
"""
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices,
which is the same as image.
Padding is used therefore `data` is `torch.Tensor`.
Padding is used therefore `audio_features` is `torch.Tensor`.
"""
audio_feature_lens: torch.Tensor
@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Shape: `(batch_size * num_audios * num_slices)`
This should be feature length of each audio slice,
which equals to `data.shape[-1]`
which equals to `audio_features.shape[-1]`
"""
audio_bounds: torch.Tensor
@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: List[torch.Tensor]
audio_embeds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices, hidden_size)`
@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
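A minimal standalone sketch (toy tensors, not vLLM's MultiModalFieldConfig machinery) of the difference between the flat_from_sizes layout removed above and the batched layout that replaces it:

import torch

num_slices = torch.tensor([2, 1, 3])           # per-audio slice counts
flat = torch.randn(int(num_slices.sum()), 80)  # flat_from_sizes-style layout

# flat_from_sizes: per-audio groups are recovered by splitting along dim 0
per_audio = list(flat.split(num_slices.tolist(), dim=0))
assert [t.shape[0] for t in per_audio] == [2, 1, 3]

# batched: each audio already owns one (possibly ragged) entry,
# so no extra bookkeeping such as audio_num_slices is needed
batched = [torch.randn(int(n), 80) for n in num_slices.tolist()]
assert len(batched) == len(num_slices)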
@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_modalities(self) -> List[str]:
return ["image", "video", "audio"]
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None}
@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (audios := mm_data.get("audios")) is None:
return {}
audios = mm_data.pop("audios", [])
audio_embeds = mm_data.pop("audio_embeds", [])
if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
audio_outputs = {
"audio_lens": [],
"audio_features": [],
"audio_feature_lens": [],
"audio_num_segments": []
}
for audio in audios:
single_audio_outputs = super().call_base_hf_processor(
prompt=self.info.audio_pattern,
mm_data={
"audios": audio,
"chunk_input": True
},
mm_kwargs=mm_kwargs)
audio_outputs["audio_lens"].append(len(audio))
audio_outputs["audio_features"].append(
single_audio_outputs["audio_features"])
audio_outputs["audio_num_segments"].append(
len(single_audio_outputs["audio_feature_lens"][0]))
audio_outputs["audio_feature_lens"] += \
single_audio_outputs["audio_feature_lens"]
audio_outputs["audio_features"] = [
audio_feature for single_audio_features in \
audio_outputs["audio_features"]
for audio_feature in single_audio_features
]
audio_outputs["audio_feature_lens"] = torch.cat(
audio_outputs["audio_feature_lens"])
elif len(audio_embeds):
audio_outputs = {
"audio_lens": [
self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0]
for chunk_embeds in single_audio_embeds))
for single_audio_embeds in audio_embeds
],
"audio_embeds": [
chunk_embeds for single_audio_embeds in audio_embeds
for chunk_embeds in single_audio_embeds
],
"audio_num_segments": [
len(single_audio_embeds)
for single_audio_embeds in audio_embeds
]
}
else:
audio_outputs = {}
return audio_outputs
parsed_audios = (self._get_data_parser().parse_mm_data({
"audio": audios
}).get_items("audio", AudioProcessorItems))
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs, "chunk_input": True
},
out_keys={"audio_features", "audio_feature_lens"},
)
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
unpadded_audio_features = [
feat[:, :feature_len] for feat, feature_len in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
)
]
audio_inputs["audio_features"] = unpadded_audio_features
return audio_inputs
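A standalone sketch (toy shapes assumed) of the unpadding step in the comment above: each padded feature is trimmed back to its true length, so one audio's processed output never depends on the padding introduced by other audios in the same batch.

import torch

feature_lens = torch.tensor([120, 80, 95])
padded = torch.zeros(3, 80, 120)  # (num_audios, mel_bins, padded_len)

# keep only the valid frames of each audio
unpadded = [feat[:, :length]
            for feat, length in zip(padded, feature_lens.tolist())]
assert [f.shape[-1] for f in unpadded] == [120, 80, 95]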
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video|audio)>./</\1>\)"
def get_placeholder_split_pattern(self) -> str:
return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
def process_mm_inputs(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, Mapping[str, NestedTensors]]:
) -> Mapping[str, NestedTensors]:
return {
"image": self.process_images(mm_data, mm_kwargs),
"video": self.process_videos(mm_data, mm_kwargs),
"audio": self.process_audios(mm_data, mm_kwargs),
**super().process_mm_inputs(mm_data, mm_kwargs),
**self.process_audios(mm_data, mm_kwargs),
}
def get_modality_num_counter(self, modality: str) -> str:
if modality == "audio":
return "audio_lens"
return super().get_modality_num_counter(modality)
def get_num_slices_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> int:
if modality == "audio":
return inputs["audio"]["audio_num_segments"][index]
return super().get_num_slices_by_modality(inputs, modality, index)
def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> str:
if modality == "audio":
return self.get_audio_prompt_texts(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6):
# Copied from HF repo of MiniCPM-o-2_6,
# designed for batched inputs and outputs
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor:
chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get(
"data",
"audio_features",
[]) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = [data.get("audio_feature_lens",
[])] # list, [[x1, x2], [y1], [z1]]
# exist audio
if len(wavforms) > 0:
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
max_seq_len, max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
else:
if len(wavforms) == 0:
return []
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
audio_inputs: Optional[MiniCPMOAudioInputs],
audio_inputs: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = audio_inputs["data"]
audio_embeddings = [
audio_embeddings[i].to(device=device, dtype=dtype)
for i in range(len(audio_embeddings))
item.to(device=device, dtype=dtype)
for item in audio_inputs["audio_embeds"]
]
else:
audio_embeddings = self.get_audio_hidden_states(
@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):
def _parse_and_validate_audio_inputs(
self, input_ids: torch.Tensor,
**kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", [])
audio_feature_lens = kwargs.pop("audio_feature_lens", [])
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
audio_start_id = kwargs.pop("audio_start_id", None)
audio_end_id = kwargs.pop("audio_end_id", None)
if audio_features is None and audio_embeds is None:
return None
audio_start_id = kwargs.pop("audio_start_id")
if not isinstance(audio_start_id, torch.Tensor):
raise ValueError("Incorrect type of audio_start_id. "
f"Got type: {type(audio_start_id)}")
audio_end_id = kwargs.pop("audio_end_id")
if not isinstance(audio_end_id, torch.Tensor):
raise ValueError("Incorrect type of audio_end_id. "
f"Got type: {type(audio_end_id)}")
if audio_embeds is not None:
audio_embeds = [
audio_embeds[i][j] for i in range(len(audio_embeds))
for j in range(len(audio_embeds[i]))
]
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
f"Got type: {type(audio_embeds)}")
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
data=audio_embeds,
type="audio_embeds")
if len(audio_features) > 0:
audio_features_all = [
i.permute(1, 0) for audio_feature in audio_features
for i in audio_feature
]
audio_features = torch.nn.utils.rnn.pad_sequence(
audio_features_all, batch_first=True,
padding_value=0.0).permute(0, 2, 1)
audio_feature_lens = torch.cat(
[item for item in audio_feature_lens])
)
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_features. "
f"Got type: {type(audio_features)}")
audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}")
return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=flatten_bn(audio_features, concat=True),
audio_feature_lens=flatten_bn(
flatten_2d_lists(audio_feature_lens), concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
data=audio_features,
audio_feature_lens=audio_feature_lens,
type="audio_features")
return None
)
raise AssertionError("This line should be unreachable.")
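A standalone approximation (toy shapes; flatten_2d_lists and flatten_bn are vLLM utility helpers) of the flattening above, which merges a nested list of per-audio embedding tensors into a single 2-D tensor:

import torch
from itertools import chain

hidden = 8
nested = [                                             # batch of 2 requests
    [torch.randn(3, hidden), torch.randn(1, hidden)],  # request 0: 2 audios
    [torch.randn(2, hidden)],                          # request 1: 1 audio
]

flat_list = list(chain.from_iterable(nested))  # roughly flatten_2d_lists
merged = torch.cat(flat_list, dim=0)           # roughly flatten_bn(concat=True)
assert merged.shape == (6, hidden)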
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
**kwargs: object):
@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
else:
image_inputs, audio_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding_with_vision(
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
if audio_inputs is not None:

View File

@ -24,6 +24,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
@ -63,11 +64,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsV0Only)
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
CPU_DEVICE = torch.device("cpu")
@ -76,7 +78,7 @@ RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: List[torch.Tensor]
pixel_values: list[torch.Tensor]
"""
Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
@ -101,7 +103,7 @@ class MiniCPMVImagePixelInputs(TypedDict):
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
image_embeds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices,
image_feature_size, hidden_size)`
@ -231,26 +233,15 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
)
@ -356,12 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_model_version(self):
return get_version_by_config(self.get_hf_config())
def get_supported_mm_modalities(self) -> List[str]:
if self.get_model_version() == (2, 6):
return ["image", "video"]
else:
return ["image"]
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
if self.get_model_version() == (2, 6):
return {"image": None, "video": None}
@ -526,187 +511,123 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def get_image_prompt_texts(self,
image_size: ImageSize,
image_idx: int = 0) -> str:
prompt_texts = self.get_slice_image_placeholder(image_size,
image_idx=image_idx)
return prompt_texts
return self.get_slice_image_placeholder(image_size,
image_idx=image_idx)
def get_video_prompt_texts(self, image_size: ImageSize,
num_frames: int) -> str:
prompt_texts = "".join(
self.get_slice_image_placeholder(
image_size=image_size,
image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False) for image_idx in range(num_frames))
return prompt_texts
return self.get_slice_image_placeholder(
image_size=image_size,
image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False,
) * num_frames
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
tokenizer = self.info.get_tokenizer()
special_tokens = {
"im_start_id": torch.tensor(tokenizer.im_start_id),
"im_end_id": torch.tensor(tokenizer.im_end_id)
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
}
if hasattr(tokenizer, "slice_start_id"):
special_tokens["slice_start_id"] = torch.tensor(
tokenizer.slice_start_id)
special_tokens["slice_end_id"] = torch.tensor(
tokenizer.slice_end_id)
return special_tokens
special_tokens["slice_start_id"] = tokenizer.slice_start_id
special_tokens["slice_end_id"] = tokenizer.slice_end_id
@staticmethod
def repack_processor_outputs(outputs: Any) -> BatchFeature:
valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"]
outputs = {key: outputs[key][0] for key in valid_keys}
return outputs
return {k: torch.tensor(v) for k, v in special_tokens.items()}
def process_images(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (images := mm_data.get("images")) is None:
return {}
images = mm_data.pop("images", [])
image_embeds = mm_data.pop("image_embeds", [])
if isinstance(images, Image.Image):
images = [images]
if isinstance(images, (list, torch.Tensor)) and len(images) > 0:
image_outputs = super()._call_hf_processor(
prompt=self.info.image_pattern * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs)
image_outputs = self.repack_processor_outputs(image_outputs)
elif len(image_embeds) > 0:
image_sizes = mm_data.pop("image_sizes", None)
image_outputs = {
"image_embeds": torch.cat(image_embeds),
"image_sizes": image_sizes
}
else:
image_outputs = {}
return image_outputs
parsed_images = (self._get_data_parser().parse_mm_data({
"image": images
}).get_items("image", ImageProcessorItems))
return self._base_call_hf_processor(
prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
def process_videos(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (videos := mm_data.get("videos")) is None:
return {}
videos = mm_data.pop("videos", [])
video_embeds = mm_data.pop("video_embeds", [])
if len(videos) > 0 and isinstance(videos[0], Image.Image):
videos = [videos]
if isinstance(videos, list) and len(videos) > 0:
video_outputs = {
"video_pixel_values": [],
"video_image_sizes": [],
"video_tgt_sizes": [],
"num_frames": []
}
for video in videos:
parsed_video = []
for frame in video:
if isinstance(frame, np.ndarray):
parsed_video.append(Image.fromarray(frame))
else:
parsed_video.append(frame)
video = parsed_video
single_video_outputs = super()._call_hf_processor(
prompt=self.info.image_pattern * len(video),
mm_data={"images": video},
mm_kwargs={
**mm_kwargs, "max_slice_nums":
self.info.get_video_max_slice_num()
})
video_outputs["num_frames"].append(len(video))
for key in single_video_outputs:
if "video_" + key in video_outputs:
if key == "image_sizes":
video_outputs["video_" + key].append(
single_video_outputs[key][0][0])
else:
video_outputs["video_" +
key] += single_video_outputs[key][0]
elif len(video_embeds):
image_sizes = mm_data.pop("image_sizes", None)
num_frames = mm_data.pop("num_frames", None)
video_outputs = {
"video_embeds": torch.cat(video_embeds),
"video_image_sizes": image_sizes,
"num_frames": num_frames
}
else:
video_outputs = {}
return video_outputs
parsed_videos = (self._get_data_parser().parse_mm_data({
"video": videos
}).get_items("video", VideoProcessorItems))
max_slice_num = self.info.get_video_max_slice_num()
video_inputs = self._base_call_hf_processor(
prompts=[
self.info.image_pattern * len(video) for video in parsed_videos
],
mm_data={"images": list(parsed_videos)},
mm_kwargs={
**mm_kwargs, "max_slice_nums": max_slice_num
},
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
return {f"video_{k}": v for k, v in video_inputs.items()}
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video)>./</\1>\)"
def get_placeholder_split_pattern(self) -> str:
return r"\(<(?:image|video)>./</(?:image|video)>\)"
def process_mm_inputs(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, Mapping[str, NestedTensors]]:
) -> Mapping[str, NestedTensors]:
return {
"image": self.process_images(mm_data, mm_kwargs),
"video": self.process_videos(mm_data, mm_kwargs),
**self.process_images(mm_data, mm_kwargs),
**self.process_videos(mm_data, mm_kwargs),
}
def get_input_modalities(self, mm_data) -> List[str]:
supported_mm_modalities = self.info.get_supported_mm_modalities()
input_modalities = []
for modality in supported_mm_modalities:
if modality in mm_data and mm_data[modality] != {}:
input_modalities.append(modality)
return input_modalities
def get_modality_num_counter(self, modality: str) -> str:
if modality == "image":
return "image_sizes"
elif modality == "video":
return "video_image_sizes"
raise NotImplementedError(modality)
def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str,
index: int) -> int:
if modality == "image":
return self.info.get_image_slice_nums(
inputs[modality]["image_sizes"][index],
self.info.get_max_slice_num())
elif modality == "video":
return self.info.get_image_slice_nums(
inputs[modality]["video_image_sizes"][index],
self.info.get_video_max_slice_num()
) * inputs[modality]["num_frames"][index]
else:
raise ValueError(f"Unexpected modality: {modality}")
def get_prompt_texts_by_modality(self, inputs: dict[str, Any],
modality: str, index: int) -> str:
if modality == "image":
return self.get_image_prompt_texts(
inputs["image"]["image_sizes"][index], index)
elif modality == "video":
return self.get_video_prompt_texts(
inputs["video"]["video_image_sizes"][index],
inputs["video"]["num_frames"][index])
else:
raise ValueError(f"Unexpected modality: {modality}")
def call_base_hf_processor(
def _base_call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
prompts: list[str],
mm_data: Mapping[str, Sequence[object]],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
return super()._call_hf_processor(prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs)
*,
out_keys: set[str],
) -> Mapping[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6):
inputs = super()._call_hf_processor(
prompt=prompts, # type: ignore
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
else:
inputs = defaultdict[str, list[torch.Tensor]](list)
for i, prompt in enumerate(prompts):
inputs_one = super()._call_hf_processor(
prompt=prompt,
mm_data={
k: v[i]
for k, v in mm_data.items()
},
mm_kwargs=mm_kwargs,
)
for k, v in inputs_one.items():
assert len(v) == 1, (k, len(v))
inputs[k].append(v[0])
return {k: inputs[k] for k in out_keys}
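A standalone sketch (stub processor, hypothetical names) of the fallback branch above for processors that cannot zip prompts and multimodal data together: call the processor once per item and collect the single-item outputs per key.

from collections import defaultdict
import torch

def stub_processor(prompt: str) -> dict[str, list[torch.Tensor]]:
    # stand-in for an HF processor call returning one entry per key
    return {"pixel_values": [torch.randn(3, 4, 4)]}

prompts = ["(<image>./</image>)", "(<image>./</image>)"]
collected = defaultdict(list)
for prompt in prompts:
    out = stub_processor(prompt)
    for key, values in out.items():
        assert len(values) == 1
        collected[key].append(values[0])

assert len(collected["pixel_values"]) == len(prompts)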
def _call_hf_processor(
self,
@ -717,35 +638,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Do not support combination inputs of images and videos for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer()
inputs = self.process_mm_inputs(mm_data, mm_kwargs)
mm_input_modalities = self.get_input_modalities(inputs)
num_mm_slices_lst = {
modality: list[int]()
for modality in mm_input_modalities
}
for modality in mm_input_modalities:
num_counter_key = self.get_modality_num_counter(modality)
for index in range(len(inputs[modality][num_counter_key])):
num_mm_slices_lst[modality].append(
self.get_num_slices_by_modality(inputs, modality, index))
num_mm_slices = {
modality: torch.tensor(v)
for modality, v in num_mm_slices_lst.items()
}
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
return BatchFeature({
"input_ids": np.array([tokenizer.encode(prompt)]),
**{
key: value
for modality in inputs
for key, value in inputs[modality].items()
},
**{
f"{modality}_num_slices": num_mm_slices[modality]
for modality in mm_input_modalities
}
"input_ids":
torch.tensor([tokenizer.encode(prompt)]),
**mm_inputs,
})
def _hf_processor_applies_updates(
@ -810,7 +708,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
supported_mm_modalities = self.info.get_supported_mm_modalities()
if isinstance(prompt, list):
prompt = self.info.get_tokenizer().decode(prompt)
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
@ -818,7 +715,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
f"{modality}_orders":
torch.tensor(
[index for index, m in enumerate(matches) if m == modality])
for modality in supported_mm_modalities
for modality in self.info.get_supported_mm_limits()
}
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
@ -884,35 +781,35 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
if image_inputs is None:
return vlm_embedding
if image_inputs["type"] == "image_embeds":
vision_hidden_states = image_inputs["image_embeds"].to(
device=vlm_embedding.device,
dtype=vlm_embedding.dtype,
)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1,
vlm_embedding.shape[-1]),
vision_hidden_states.view(-1,
vision_hidden_states.shape[-1]),
)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
return vlm_embedding, vision_hidden_states
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding
def _get_image_bounds(
self,
@ -947,90 +844,115 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
mm_data = {
image_keys = {"pixel_values", "tgt_sizes"}
pixel_data = {
"image": {
key: kwargs.pop(key, [])
for key in ["pixel_values", "tgt_sizes", "image_num_slices"]
key: kwargs.pop(key, None)
for key in image_keys
},
"video": {
"pixel_values": kwargs.pop("video_pixel_values", []),
"tgt_sizes": kwargs.pop("video_tgt_sizes", []),
"video_num_slices": kwargs.pop("video_num_slices", [])
key: kwargs.pop("video_" + key, None)
for key in image_keys
}
}
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
mm_orders = {
f"{modality}": kwargs.pop(f"{modality}_orders", None)
for modality in ["image", "video", "audio"]
embed_data = {
"image": kwargs.pop("image_embeds", None),
"video": kwargs.pop("video_embeds", None),
}
batch_size = max(len(mm_data["image"]["pixel_values"]),
len(mm_data["video"]["pixel_values"]))
image_embeds = kwargs.pop("image_embeds", None)
video_embeds = kwargs.pop("video_embeds", None)
if image_embeds is not None and video_embeds is not None:
raise ValueError(
"Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not exist simultaneously.")
if video_embeds is not None:
image_embeds = video_embeds
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
image_embeds = torch.concat(
[image_embeds[i] for i in range(len(image_embeds))])
all_pixel_data = [
v for vs in pixel_data.values() for v in vs.values()
if v is not None
]
all_embed_data = [v for v in embed_data.values() if v is not None]
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
return None
im_start_id = kwargs.pop("im_start_id")
if not isinstance(im_start_id, torch.Tensor):
raise ValueError("Incorrect type of im_start_id. "
f"Got type: {type(im_start_id)}")
im_end_id = kwargs.pop("im_end_id")
if not isinstance(im_end_id, torch.Tensor):
raise ValueError("Incorrect type of im_end_id. "
f"Got type: {type(im_end_id)}")
slice_start_id = kwargs.pop("slice_start_id", None)
if slice_start_id is not None and not isinstance(
slice_start_id, torch.Tensor):
raise ValueError("Incorrect type of slice_start_id. "
f"Got type: {type(slice_start_id)}")
slice_end_id = kwargs.pop("slice_end_id", None)
if slice_end_id is not None and not isinstance(slice_end_id,
torch.Tensor):
raise ValueError("Incorrect type of slice_end_id. "
f"Got type: {type(slice_end_id)}")
if len(all_embed_data) > 0:
if len(all_embed_data) > 1:
raise ValueError("Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not "
"exist simultaneously.")
vision_embeds, = all_embed_data
if not isinstance(vision_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of vision_embeds. "
f"Got type: {type(vision_embeds)}")
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
concat=True),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
for modality, modality_mm_data in mm_data.items():
if not isinstance(modality_mm_data["pixel_values"],
(torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(modality_mm_data['pixel_values'])}")
if not isinstance(modality_mm_data["tgt_sizes"],
(torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. "
f"Got type: {type(modality_mm_data['tgt_sizes'])}")
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
for modality in ("image", "video"):
modality_orders = kwargs.pop(f"{modality}_orders", None)
if modality_orders is not None:
if not isinstance(modality_orders, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {modality}_orders. "
f"Got type: {type(modality_orders)}")
if len(modality_mm_data["pixel_values"]) != len(
modality_mm_data["tgt_sizes"]):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(modality_mm_data['pixel_values'])} vs. "
f"{len(modality_mm_data['tgt_sizes'])}")
order_data[modality] = modality_orders
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
batch_sizes = {
modality: len(modality_orders)
for modality, modality_orders in order_data.items()
}
unique_batch_sizes = set(batch_sizes.values())
assert len(unique_batch_sizes) == 1, (
f"Found inconsistent batch sizes: {batch_sizes}")
batch_size, = unique_batch_sizes
pixel_values_flat = list[torch.Tensor]()
tgt_sizes_flat = list[torch.Tensor]()
for b in range(batch_size):
mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \
else {"image": 0}
mm_slice_counts = {"image": 0, "video": 0} \
if self.version == (2, 6) else {"image": 0}
mm_orders_b = [(index, modality) for modality in mm_counts
for index in mm_orders[modality][b]]
mm_orders_b = [(idx_b.item(), modality)
for modality, modality_orders in order_data.items()
for idx_b in modality_orders[b]]
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
pos = mm_counts[modality]
num_slices = mm_data[modality][f"{modality}_num_slices"][b][
pos]
slice_start_idx = mm_slice_counts[modality]
slice_end_idx = slice_start_idx + num_slices
pixel_values_flat += mm_data[modality]["pixel_values"][b][
slice_start_idx:slice_end_idx]
tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][
slice_start_idx:slice_end_idx]
mm_counts[modality] += 1
mm_slice_counts[modality] += num_slices
modality_pixel_data = pixel_data[modality]
modality_pixel_values = modality_pixel_data["pixel_values"]
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(modality_pixel_values)}")
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(modality_tgt_sizes)}")
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
@ -1042,16 +964,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
if len(pixel_values_flat) == 0:
return None
if im_start_id is None:
return None
return MiniCPMVImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
)
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
@ -1070,7 +989,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
else:
image_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding_with_vision(
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
# always pass the input via `inputs_embeds`
@ -1136,16 +1055,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
prefix: str = "") -> nn.Module:
raise NotImplementedError
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
raise NotImplementedError
@ -1216,35 +1127,27 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res = []
dtype = self.vpm.pos_embed.data.dtype
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
P_h, P_w = self.vpm.patch_embed.patch_size
dtype: torch.dtype = self.vpm.pos_embed.data.dtype
num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)
res = list[torch.Tensor]()
for pixel_value in pixel_values:
H, W = pixel_value[0].shape[-2:]
tgt_size = (
math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]),
)
tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if (hasattr(self.vpm, "num_prefix_tokens")
and self.vpm.num_prefix_tokens > 0):
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
if num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
return self.get_vision_embedding(pixel_values)
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@ -1299,45 +1202,41 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(pixel_values,
patch_attention_mask=patch_attn_mask)
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)
device = pixel_values[0].device
dtype = pixel_values[0].dtype
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
all_pixel_values = torch.zeros((B, 3, P, L),
dtype=dtype,
device=device)
for i, pixel_values_item in enumerate(pixel_values):
L_item = pixel_values_item.shape[-1]
all_pixel_values[i, ..., :L_item] = pixel_values_item
num_patches = tgt_sizes.prod(-1)
max_patches = num_patches.max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
patch_attn_mask = torch.zeros((B, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
for i, num_patches_item in enumerate(num_patches):
patch_attn_mask[i, :num_patches_item] = True
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
vision_embedding = self.vpm(
all_pixel_values,
patch_attention_mask=patch_attn_mask.unsqueeze(1),
tgt_sizes=None,
)
return self.resampler(vision_embedding, tgt_sizes)
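A standalone sketch (toy sizes) of the padding and attention-mask pattern used above: ragged per-image patch sequences are padded to a common length with pad_sequence, and a boolean mask marks which positions are real patches.

import torch
from torch.nn.utils.rnn import pad_sequence

hidden = 16
num_patches = torch.tensor([5, 3, 7])                      # per-image patch counts
seqs = [torch.randn(int(n), hidden) for n in num_patches]  # ragged sequences

padded = pad_sequence(seqs, batch_first=True, padding_value=0.0)
assert padded.shape == (3, int(num_patches.max()), hidden)

patch_attn_mask = torch.zeros(3, int(num_patches.max()), dtype=torch.bool)
for i, n in enumerate(num_patches.tolist()):
    patch_attn_mask[i, :n] = True                          # True = real patch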
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
@ -1394,47 +1293,37 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)
device = pixel_values[0].device
dtype = pixel_values[0].dtype
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
all_pixel_values = torch.zeros((B, 3, P, L),
dtype=dtype,
device=device)
for i, pixel_values_item in enumerate(pixel_values):
L_item = pixel_values_item.shape[-1]
all_pixel_values[i, ..., :L_item] = pixel_values_item
num_patches = tgt_sizes.prod(-1)
max_patches = num_patches.max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
patch_attn_mask = torch.zeros((B, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
for i, num_patches_item in enumerate(num_patches):
patch_attn_mask[i, :num_patches_item] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
all_pixel_values,
patch_attention_mask=patch_attn_mask.unsqueeze(1),
tgt_sizes=tgt_sizes,
)

View File

@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped)
def __delitem__(self, key: str) -> None:
super().__delitem__(key)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, None)
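A standalone toy (not the actual MultiModalKwargs class) showing the effect of the __delitem__ override added above: deleting a top-level key also removes it from the nested per-modality items, so both views stay consistent.

from collections import UserDict

class SyncedKwargs(UserDict):
    def __init__(self, data, items_by_modality):
        super().__init__(data)
        self._items_by_modality = items_by_modality

    def __delitem__(self, key):
        super().__delitem__(key)
        # keep nested per-modality item dicts in sync with the top level
        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

kwargs = SyncedKwargs({"audio_features": 1, "tgt_sizes": 2},
                      {"audio": [{"audio_features": 1}]})
del kwargs["audio_features"]
assert "audio_features" not in kwargs
assert "audio_features" not in kwargs._items_by_modality["audio"][0]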
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False