[Misc] Clean up MiniCPM-V/O code (#15337)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-03-25 18:22:52 +08:00 (committed via GitHub)
Commit: a9e879b316 (parent: 3e2f37a69a)
7 changed files with 521 additions and 651 deletions

View File

@@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
     engine_args = EngineArgs(
         model="llava-hf/LLaVA-NeXT-Video-7B-hf",
         max_model_len=8192,
+        max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

View File

@@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
-    # "aria": VLMTestInfo(
-    #     models=["rhymes-ai/Aria"],
-    #     test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-    #     prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ",  # noqa: E501
-    #     img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
-    #     max_model_len=4096,
-    #     max_num_seqs=2,
-    #     auto_cls=AutoModelForImageTextToText,
-    #     single_image_prompts=IMAGE_ASSETS.prompts({
-    #         "stop_sign": "<vlm_image>Please describe the image shortly.",
-    #         "cherry_blossom": "<vlm_image>Please infer the season with reason.",  # noqa: E501
-    #     }),
-    #     multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.",  # noqa: E501
-    #     stop_str=["<|im_end|>"],
-    #     image_size_factors=[(0.10, 0.15)],
-    #     max_tokens=64,
-    #     marks=[large_gpu_mark(min_gb=64)],
-    # ),
+    "aria": VLMTestInfo(
+        models=["rhymes-ai/Aria"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<vlm_image>Please describe the image shortly.",
+            "cherry_blossom": "<vlm_image>Please infer the season with reason.",  # noqa: E501
+        }),
+        multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.",  # noqa: E501
+        stop_str=["<|im_end|>"],
+        image_size_factors=[(0.10, 0.15)],
+        max_tokens=64,
+        marks=[large_gpu_mark(min_gb=64)],
+    ),
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
@@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
         prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
         num_video_frames=16,
         max_model_len=4096,
+        max_num_seqs=2,
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
     ),
@@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = {
     ),
     "minicpmo_26": VLMTestInfo(
         models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        test_type=(VLMTestType.IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
     ),
+    "minicpmo_26_multi_image": VLMTestInfo(
+        models=["openbmb/MiniCPM-o-2_6"],
+        test_type=(VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        test_type=(VLMTestType.IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -404,6 +417,18 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
     ),
+    "minicpmv_26_multi_image": VLMTestInfo(
+        models=["openbmb/MiniCPM-V-2_6"],
+        test_type=(VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE),

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0

-import copy
 from functools import partial
 from typing import Optional, Union
@@ -29,7 +28,7 @@ def _test_processing_correctness(
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
         # For some multimodal models, tokenizer will always add bos_token
@@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_result,
         cached_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )

     baseline_tokenized_result = baseline_processor.apply(
         token_prompt,
@@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_result,
         baseline_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )

     cached_tokenized_result = cached_processor.apply(
         token_prompt,
@@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         cached_result,
         cached_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )


 def _test_processing_correctness_mistral(
@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     images = mm_data.get("image", [])
     if not isinstance(images, list):
@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_tokenized_result,
         cached_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )


 # yapf: disable
@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
     "TIGER-Lab/Mantis-8B-siglip-llama3",
     "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
+    "openbmb/MiniCPM-Llama3-V-2_5",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
     "allenai/Molmo-7B-D-0924",
@@ -290,7 +294,7 @@ def test_processing_correctness(
     # In Ultravox, the audio_features can be different depending on padding
     # The slight difference should not be a problem though, since
    # attention_mask lets us ignore the difference.
-    ignore_mm_keys = ['audio_features']
+    ignore_mm_keys = {"audio_features"}

     _test_processing_correctness(
         model_id,
@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
     )


-def _inputs_equal(
+def _assert_inputs_equal(
     a: MultiModalInputs,
     b: MultiModalInputs,
-    ignore_mm_keys: Optional[list[str]] = None,
+    *,
+    ignore_mm_keys: Optional[set[str]] = None,
+    msg: str = "",
 ):
-    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
-        b, ignore_mm_keys)
-
-
-def _drop_mm_kwargs_keys(
-    result: MultiModalInputs,
-    ignore_mm_keys: Optional[list[str]] = None,
-) -> MultiModalInputs:
-    """Drop specified keys from result['mm_kwargs'].
-
-    This is mainly to avoid doing exact match of audio_features in ultravox.
-
-    Args:
-        result: Result to drop keys from
-        ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
-    """
-    if not ignore_mm_keys:
-        return result
-
-    if 'mm_kwargs' in result:
-        result = copy.deepcopy(result)
-        mm_kwargs = result['mm_kwargs']
-
-        for key in ignore_mm_keys:
-            mm_kwargs.pop(key, None)
-
-        for items in mm_kwargs._items_by_modality.values():
-            for item in items:
-                for key in ignore_mm_keys:
-                    item.pop(key, None)
-
-    return result
+    if ignore_mm_keys is None:
+        ignore_mm_keys = set()
+
+    if msg is None:
+        assert "mm_kwargs" in a and "mm_kwargs" in b
+    else:
+        assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+
+    for key in ignore_mm_keys:
+        a["mm_kwargs"].pop(key, None)
+        b["mm_kwargs"].pop(key, None)
+
+    if msg is None:
+        assert a == b
+    else:
+        assert a == b, msg
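
Note on the helper above: a minimal usage sketch (baseline_result and cached_result are assumed to be MultiModalInputs produced by two processor runs; the set-valued ignore_mm_keys mirrors the Ultravox case that skips padded audio features):

    _assert_inputs_equal(
        baseline_result,
        cached_result,
        ignore_mm_keys={"audio_features"},
        msg="cached processor output diverged from the baseline",
    )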

View File

@@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
-            assert isinstance(images, list)
-
             parsed_images = (self._get_data_parser().parse_mm_data({
                 "image":
                 images

View File

@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
+from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
@@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
                        MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
                        _minicpmv_field_config)
-from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
+from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
+                    maybe_prefix)

 CPU_DEVICE = torch.device("cpu")


 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    audio_features: torch.Tensor
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
     which is the same as image.

-    Padding is used therefore `data` is `torch.Tensor`
+    Padding is used therefore `audio_features` is `torch.Tensor`
     """

     audio_feature_lens: torch.Tensor
@@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     Shape: `(batch_size * num_audios * num_slices)`
     This should be feature length of each audio slice,
-    which equals to `data.shape[-1]`
+    which equals to `audio_features.shape[-1]`
     """

     audio_bounds: torch.Tensor
@@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: List[torch.Tensor]
+    audio_embeds: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_slices, hidden_size)`
@@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
-        audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_num_slices=MultiModalFieldConfig.batched("audio"),
-        audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-        audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
+        audio_features=MultiModalFieldConfig.batched("audio"),
+        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
     )
@@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
 class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     audio_pattern = "(<audio>./</audio>)"

-    def get_supported_mm_modalities(self) -> List[str]:
-        return ["image", "video", "audio"]
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None, "video": None, "audio": None}
@@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-
-        audios = mm_data.pop("audios", [])
-        audio_embeds = mm_data.pop("audio_embeds", [])
-        if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
-            audio_outputs = {
-                "audio_lens": [],
-                "audio_features": [],
-                "audio_feature_lens": [],
-                "audio_num_segments": []
-            }
-            for audio in audios:
-                single_audio_outputs = super().call_base_hf_processor(
-                    prompt=self.info.audio_pattern,
-                    mm_data={
-                        "audios": audio,
-                        "chunk_input": True
-                    },
-                    mm_kwargs=mm_kwargs)
-                audio_outputs["audio_lens"].append(len(audio))
-                audio_outputs["audio_features"].append(
-                    single_audio_outputs["audio_features"])
-                audio_outputs["audio_num_segments"].append(
-                    len(single_audio_outputs["audio_feature_lens"][0]))
-                audio_outputs["audio_feature_lens"] += \
-                    single_audio_outputs["audio_feature_lens"]
-            audio_outputs["audio_features"] = [
-                audio_feature for single_audio_features in \
-                    audio_outputs["audio_features"]
-                for audio_feature in single_audio_features
-            ]
-            audio_outputs["audio_feature_lens"] = torch.cat(
-                audio_outputs["audio_feature_lens"])
-        elif len(audio_embeds):
-            audio_outputs = {
-                "audio_lens": [
-                    self.info.get_audio_len_by_num_chunks(
-                        sum(chunk_embeds.shape[0]
-                            for chunk_embeds in single_audio_embeds))
-                    for single_audio_embeds in audio_embeds
-                ],
-                "audio_embeds": [
-                    chunk_embeds for single_audio_embeds in audio_embeds
-                    for chunk_embeds in single_audio_embeds
-                ],
-                "audio_num_segments": [
-                    len(single_audio_embeds)
-                    for single_audio_embeds in audio_embeds
-                ]
-            }
-        else:
-            audio_outputs = {}
-        return audio_outputs
+        if (audios := mm_data.get("audios")) is None:
+            return {}
+
+        parsed_audios = (self._get_data_parser().parse_mm_data({
+            "audio": audios
+        }).get_items("audio", AudioProcessorItems))
+
+        audio_inputs = self._base_call_hf_processor(
+            prompts=[self.info.audio_pattern] * len(parsed_audios),
+            mm_data={"audios": [[audio] for audio in parsed_audios]},
+            mm_kwargs={
+                **mm_kwargs, "chunk_input": True
+            },
+            out_keys={"audio_features", "audio_feature_lens"},
+        )
+
+        # Avoid padding since we need the output for each audio to be
+        # independent of other audios for the cache to work correctly
+        unpadded_audio_features = [
+            feat[:, :feature_len] for feat, feature_len in zip(
+                audio_inputs["audio_features"],
+                audio_inputs["audio_feature_lens"],
+            )
+        ]
+        audio_inputs["audio_features"] = unpadded_audio_features
+
+        return audio_inputs

     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video|audio)>./</\1>\)"

-    def get_placeholder_split_pattern(self) -> str:
-        return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
-
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, Mapping[str, NestedTensors]]:
+    ) -> Mapping[str, NestedTensors]:
         return {
-            "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs),
-            "audio": self.process_audios(mm_data, mm_kwargs),
+            **super().process_mm_inputs(mm_data, mm_kwargs),
+            **self.process_audios(mm_data, mm_kwargs),
         }

-    def get_modality_num_counter(self, modality: str) -> str:
-        if modality == "audio":
-            return "audio_lens"
-        return super().get_modality_num_counter(modality)
-
-    def get_num_slices_by_modality(self, inputs: Dict[str, object],
-                                   modality: str, index: int) -> int:
-        if modality == "audio":
-            return inputs["audio"]["audio_num_segments"][index]
-        return super().get_num_slices_by_modality(inputs, modality, index)
-
-    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
-                                     modality: str, index: int) -> str:
-        if modality == "audio":
-            return self.get_audio_prompt_texts(
-                inputs["audio"]["audio_lens"][index])
-        return super().get_prompt_texts_by_modality(inputs, modality, index)
-
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -622,15 +566,16 @@ class MiniCPMO(MiniCPMV2_6):
     # Copied from HF repo of MiniCPM-o-2_6,
     # designed for batched inputs and outputs
     def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
-                                chunk_length: int) -> torch.Tensor:
+                                chunk_length: int) -> list[torch.Tensor]:
         wavforms = data.get(
-            "data",
+            "audio_features",
             [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = [data.get("audio_feature_lens",
                                            [])]  # list, [[x1, x2], [y1], [z1]]

-        # exist audio
-        if len(wavforms) > 0:
+        if len(wavforms) == 0:
+            return []
+
         audio_feature_lens = torch.hstack(audio_feature_lens_raw)
         batch_size, _, max_mel_seq_len = wavforms.shape
         max_seq_len = (max_mel_seq_len - 1) // 2 + 1
@@ -648,8 +593,8 @@ class MiniCPMO(MiniCPMV2_6):
         padding_mask = seq_range >= lengths_expand  # 1 for padded values

         audio_attention_mask_ = padding_mask.view(
-            batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
-                                                  max_seq_len, max_seq_len)
+            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
+                                                  max_seq_len)
         audio_attention_mask = audio_attention_mask_.to(
             dtype=self.apm.conv1.weight.dtype,
             device=self.apm.conv1.weight.device)
@@ -690,18 +635,15 @@ class MiniCPMO(MiniCPMV2_6):
                 idx += 1

             final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds
-        else:
-            return []

     def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
-                                  audio_inputs: Optional[MiniCPMOAudioInputs],
+                                  audio_inputs: MiniCPMOAudioInputs,
                                   chunk_length: int) -> torch.Tensor:
         device, dtype = vlm_embedding.device, vlm_embedding.dtype
         if audio_inputs["type"] == "audio_embeds":
-            audio_embeddings = audio_inputs["data"]
             audio_embeddings = [
-                audio_embeddings[i].to(device=device, dtype=dtype)
-                for i in range(len(audio_embeddings))
+                item.to(device=device, dtype=dtype)
+                for item in audio_inputs["audio_embeds"]
             ]
         else:
             audio_embeddings = self.get_audio_hidden_states(
@@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):
     def _parse_and_validate_audio_inputs(
             self, input_ids: torch.Tensor,
-            **kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
-        audio_features = kwargs.pop("audio_features", [])
-        audio_feature_lens = kwargs.pop("audio_feature_lens", [])
-        audio_embeds = kwargs.pop("audio_embeds", None)
-        audio_start_id = kwargs.pop("audio_start_id", None)
-        audio_end_id = kwargs.pop("audio_end_id", None)
-        if audio_embeds is not None:
-            audio_embeds = [
-                audio_embeds[i][j] for i in range(len(audio_embeds))
-                for j in range(len(audio_embeds[i]))
-            ]
-            return MiniCPMOAudioEmbeddingInputs(
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
-                data=audio_embeds,
-                type="audio_embeds")
-        if len(audio_features) > 0:
-            audio_features_all = [
-                i.permute(1, 0) for audio_feature in audio_features
-                for i in audio_feature
-            ]
-            audio_features = torch.nn.utils.rnn.pad_sequence(
-                audio_features_all, batch_first=True,
-                padding_value=0.0).permute(0, 2, 1)
-            audio_feature_lens = torch.cat(
-                [item for item in audio_feature_lens])
-            return MiniCPMOAudioFeatureInputs(
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
-                data=audio_features,
-                audio_feature_lens=audio_feature_lens,
-                type="audio_features")
-        return None
+            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
+        audio_features = kwargs.pop("audio_features", None)
+        audio_embeds = kwargs.pop("audio_embeds", None)
+
+        if audio_features is None and audio_embeds is None:
+            return None
+
+        audio_start_id = kwargs.pop("audio_start_id")
+        if not isinstance(audio_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_start_id. "
+                             f"Got type: {type(audio_start_id)}")
+
+        audio_end_id = kwargs.pop("audio_end_id")
+        if not isinstance(audio_end_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_end_id. "
+                             f"Got type: {type(audio_end_id)}")
+
+        if audio_embeds is not None:
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+
+            return MiniCPMOAudioEmbeddingInputs(
+                type="audio_embeds",
+                audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
+                                        concat=True),
+                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
+                                                    audio_end_id),
+            )
+
+        if audio_features is not None:
+            if not isinstance(audio_features, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_features. "
+                                 f"Got type: {type(audio_features)}")
+
+            audio_feature_lens = kwargs.pop("audio_feature_lens")
+            if not isinstance(audio_feature_lens, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_feature_lens. "
+                                 f"Got type: {type(audio_feature_lens)}")
+
+            return MiniCPMOAudioFeatureInputs(
+                type="audio_features",
+                audio_features=flatten_bn(audio_features, concat=True),
+                audio_feature_lens=flatten_bn(
+                    flatten_2d_lists(audio_feature_lens), concat=True),
+                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
+                                                    audio_end_id),
+            )
+
+        raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                    **kwargs: object):
@@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
         else:
             image_inputs, audio_inputs = \
                 self._parse_and_validate_inputs(input_ids, **kwargs)

-            vlm_embeddings, _ = self.get_embedding_with_vision(
+            vlm_embeddings = self.get_embedding_with_vision(
                 input_ids, image_inputs)

             if audio_inputs is not None:
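
The unpadding step in the new process_audios can be shown in isolation. A self-contained sketch of the same slicing pattern (the tensor shapes below are invented for illustration, not taken from the model):

    import torch

    # Two padded 80-mel audio slices; lens holds the valid length of each
    # slice, playing the role of audio_feature_lens in the processor output.
    feats = torch.randn(2, 80, 100)
    lens = torch.tensor([60, 100])

    # Trim each slice back to its own length so every audio's features stay
    # independent of whatever else happened to be in the batch.
    unpadded = [feat[:, :n] for feat, n in zip(feats, lens)]
    assert [u.shape[-1] for u in unpadded] == [60, 100]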

View File

@@ -24,6 +24,7 @@
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
 import math
 import re
+from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property, partial
 from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
@@ -63,11 +64,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsV0Only)
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix

 CPU_DEVICE = torch.device("cpu")
@@ -76,7 +78,7 @@ RawImageType = Union[Image.Image, torch.Tensor]
 class MiniCPMVImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: List[torch.Tensor]
+    pixel_values: list[torch.Tensor]
     """
     Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
@@ -101,7 +103,7 @@ class MiniCPMVImagePixelInputs(TypedDict):
 class MiniCPMVImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: torch.Tensor
+    image_embeds: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_slices,
     image_feature_size, hidden_size)`
@@ -231,26 +233,15 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
 def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
-    video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
-
     return dict(
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-        image_sizes=MultiModalFieldConfig.batched("image"),
-        tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-        image_num_slices=MultiModalFieldConfig.batched("image"),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-        video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
-        video_image_sizes=MultiModalFieldConfig.batched("video"),
-        video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
-        video_num_slices=MultiModalFieldConfig.batched("video"),
+        pixel_values=MultiModalFieldConfig.batched("image"),
+        image_sizes=MultiModalFieldConfig.batched("image"),
+        tgt_sizes=MultiModalFieldConfig.batched("image"),
+        image_embeds=MultiModalFieldConfig.batched("image"),
+        video_pixel_values=MultiModalFieldConfig.batched("video"),
+        video_image_sizes=MultiModalFieldConfig.batched("video"),
+        video_tgt_sizes=MultiModalFieldConfig.batched("video"),
+        video_embeds=MultiModalFieldConfig.batched("video"),
     )
@@ -356,12 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
     def get_model_version(self):
         return get_version_by_config(self.get_hf_config())

-    def get_supported_mm_modalities(self) -> List[str]:
-        if self.get_model_version() == (2, 6):
-            return ["image", "video"]
-        else:
-            return ["image"]
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         if self.get_model_version() == (2, 6):
             return {"image": None, "video": None}
@@ -526,187 +511,123 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def get_image_prompt_texts(self,
                                image_size: ImageSize,
                                image_idx: int = 0) -> str:
-        prompt_texts = self.get_slice_image_placeholder(image_size,
-                                                        image_idx=image_idx)
-        return prompt_texts
+        return self.get_slice_image_placeholder(image_size,
+                                                image_idx=image_idx)

     def get_video_prompt_texts(self, image_size: ImageSize,
                                num_frames: int) -> str:
-        prompt_texts = "".join(
-            self.get_slice_image_placeholder(
-                image_size=image_size,
-                image_idx=0,
-                max_slice_nums=self.info.get_video_max_slice_num(),
-                use_image_id=False) for image_idx in range(num_frames))
-        return prompt_texts
+        return self.get_slice_image_placeholder(
+            image_size=image_size,
+            image_idx=0,
+            max_slice_nums=self.info.get_video_max_slice_num(),
+            use_image_id=False,
+        ) * num_frames

     def get_special_tokens(self) -> Dict[str, torch.Tensor]:
         tokenizer = self.info.get_tokenizer()
+
         special_tokens = {
-            "im_start_id": torch.tensor(tokenizer.im_start_id),
-            "im_end_id": torch.tensor(tokenizer.im_end_id)
+            "im_start_id": tokenizer.im_start_id,
+            "im_end_id": tokenizer.im_end_id,
         }
         if hasattr(tokenizer, "slice_start_id"):
-            special_tokens["slice_start_id"] = torch.tensor(
-                tokenizer.slice_start_id)
-            special_tokens["slice_end_id"] = torch.tensor(
-                tokenizer.slice_end_id)
-        return special_tokens
-
-    @staticmethod
-    def repack_processor_outputs(outputs: Any) -> BatchFeature:
-        valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"]
-        outputs = {key: outputs[key][0] for key in valid_keys}
-        return outputs
+            special_tokens["slice_start_id"] = tokenizer.slice_start_id
+            special_tokens["slice_end_id"] = tokenizer.slice_end_id
+
+        return {k: torch.tensor(v) for k, v in special_tokens.items()}

     def process_images(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-
-        images = mm_data.pop("images", [])
-        image_embeds = mm_data.pop("image_embeds", [])
-        if isinstance(images, Image.Image):
-            images = [images]
-        if isinstance(images, (list, torch.Tensor)) and len(images) > 0:
-            image_outputs = super()._call_hf_processor(
-                prompt=self.info.image_pattern * len(images),
-                mm_data={"images": images},
-                mm_kwargs=mm_kwargs)
-            image_outputs = self.repack_processor_outputs(image_outputs)
-        elif len(image_embeds) > 0:
-            image_sizes = mm_data.pop("image_sizes", None)
-            image_outputs = {
-                "image_embeds": torch.cat(image_embeds),
-                "image_sizes": image_sizes
-            }
-        else:
-            image_outputs = {}
-        return image_outputs
+        if (images := mm_data.get("images")) is None:
+            return {}
+
+        parsed_images = (self._get_data_parser().parse_mm_data({
+            "image": images
+        }).get_items("image", ImageProcessorItems))
+
+        return self._base_call_hf_processor(
+            prompts=[self.info.image_pattern] * len(parsed_images),
+            mm_data={"images": [[image] for image in parsed_images]},
+            mm_kwargs=mm_kwargs,
+            out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
+        )

     def process_videos(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-
-        videos = mm_data.pop("videos", [])
-        video_embeds = mm_data.pop("video_embeds", [])
-        if len(videos) > 0 and isinstance(videos[0], Image.Image):
-            videos = [videos]
-        if isinstance(videos, list) and len(videos) > 0:
-            video_outputs = {
-                "video_pixel_values": [],
-                "video_image_sizes": [],
-                "video_tgt_sizes": [],
-                "num_frames": []
-            }
-            for video in videos:
-                parsed_video = []
-                for frame in video:
-                    if isinstance(frame, np.ndarray):
-                        parsed_video.append(Image.fromarray(frame))
-                    else:
-                        parsed_video.append(frame)
-                video = parsed_video
-                single_video_outputs = super()._call_hf_processor(
-                    prompt=self.info.image_pattern * len(video),
-                    mm_data={"images": video},
-                    mm_kwargs={
-                        **mm_kwargs, "max_slice_nums":
-                        self.info.get_video_max_slice_num()
-                    })
-                video_outputs["num_frames"].append(len(video))
-                for key in single_video_outputs:
-                    if "video_" + key in video_outputs:
-                        if key == "image_sizes":
-                            video_outputs["video_" + key].append(
-                                single_video_outputs[key][0][0])
-                        else:
-                            video_outputs["video_" +
-                                          key] += single_video_outputs[key][0]
-        elif len(video_embeds):
-            image_sizes = mm_data.pop("image_sizes", None)
-            num_frames = mm_data.pop("num_frames", None)
-            video_outputs = {
-                "video_embeds": torch.cat(video_embeds),
-                "video_image_sizes": image_sizes,
-                "num_frames": num_frames
-            }
-        else:
-            video_outputs = {}
-        return video_outputs
+        if (videos := mm_data.get("videos")) is None:
+            return {}
+
+        parsed_videos = (self._get_data_parser().parse_mm_data({
+            "video": videos
+        }).get_items("video", VideoProcessorItems))
+
+        max_slice_num = self.info.get_video_max_slice_num()
+
+        video_inputs = self._base_call_hf_processor(
+            prompts=[
+                self.info.image_pattern * len(video) for video in parsed_videos
+            ],
+            mm_data={"images": list(parsed_videos)},
+            mm_kwargs={
+                **mm_kwargs, "max_slice_nums": max_slice_num
+            },
+            out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
+        )
+
+        return {f"video_{k}": v for k, v in video_inputs.items()}

     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video)>./</\1>\)"

-    def get_placeholder_split_pattern(self) -> str:
-        return r"\(<(?:image|video)>./</(?:image|video)>\)"
-
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, Mapping[str, NestedTensors]]:
+    ) -> Mapping[str, NestedTensors]:
         return {
-            "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs),
+            **self.process_images(mm_data, mm_kwargs),
+            **self.process_videos(mm_data, mm_kwargs),
         }

-    def get_input_modalities(self, mm_data) -> List[str]:
-        supported_mm_modalities = self.info.get_supported_mm_modalities()
-        input_modalities = []
-        for modality in supported_mm_modalities:
-            if modality in mm_data and mm_data[modality] != {}:
-                input_modalities.append(modality)
-        return input_modalities
-
-    def get_modality_num_counter(self, modality: str) -> str:
-        if modality == "image":
-            return "image_sizes"
-        elif modality == "video":
-            return "video_image_sizes"
-        raise NotImplementedError(modality)
-
-    def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str,
-                                   index: int) -> int:
-        if modality == "image":
-            return self.info.get_image_slice_nums(
-                inputs[modality]["image_sizes"][index],
-                self.info.get_max_slice_num())
-        elif modality == "video":
-            return self.info.get_image_slice_nums(
-                inputs[modality]["video_image_sizes"][index],
-                self.info.get_video_max_slice_num()
-            ) * inputs[modality]["num_frames"][index]
-        else:
-            raise ValueError(f"Unexpected modality: {modality}")
-
-    def get_prompt_texts_by_modality(self, inputs: dict[str, Any],
-                                     modality: str, index: int) -> str:
-        if modality == "image":
-            return self.get_image_prompt_texts(
-                inputs["image"]["image_sizes"][index], index)
-        elif modality == "video":
-            return self.get_video_prompt_texts(
-                inputs["video"]["video_image_sizes"][index],
-                inputs["video"]["num_frames"][index])
-        else:
-            raise ValueError(f"Unexpected modality: {modality}")
-
-    def call_base_hf_processor(
+    def _base_call_hf_processor(
         self,
-        prompt: str,
-        mm_data: Mapping[str, object],
+        prompts: list[str],
+        mm_data: Mapping[str, Sequence[object]],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        return super()._call_hf_processor(prompt=prompt,
-                                          mm_data=mm_data,
-                                          mm_kwargs=mm_kwargs)
+        *,
+        out_keys: set[str],
+    ) -> Mapping[str, NestedTensors]:
+        # This processor supports zipping prompt and mm_data together
+        if self.info.get_model_version() == (2, 6):
+            inputs = super()._call_hf_processor(
+                prompt=prompts,  # type: ignore
+                mm_data=mm_data,
+                mm_kwargs=mm_kwargs,
+            )
+        else:
+            inputs = defaultdict[str, list[torch.Tensor]](list)
+            for i, prompt in enumerate(prompts):
+                inputs_one = super()._call_hf_processor(
+                    prompt=prompt,
+                    mm_data={
+                        k: v[i]
+                        for k, v in mm_data.items()
+                    },
+                    mm_kwargs=mm_kwargs,
+                )
+                for k, v in inputs_one.items():
+                    assert len(v) == 1, (k, len(v))
+                    inputs[k].append(v[0])
+
+        return {k: inputs[k] for k in out_keys}

     def _call_hf_processor(
         self,
@@ -717,35 +638,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         # Do not support combination inputs of images and videos for now
         # Try to handle interleaved multimodal data
         tokenizer = self.info.get_tokenizer()
-        inputs = self.process_mm_inputs(mm_data, mm_kwargs)
-        mm_input_modalities = self.get_input_modalities(inputs)
-
-        num_mm_slices_lst = {
-            modality: list[int]()
-            for modality in mm_input_modalities
-        }
-        for modality in mm_input_modalities:
-            num_counter_key = self.get_modality_num_counter(modality)
-            for index in range(len(inputs[modality][num_counter_key])):
-                num_mm_slices_lst[modality].append(
-                    self.get_num_slices_by_modality(inputs, modality, index))
-        num_mm_slices = {
-            modality: torch.tensor(v)
-            for modality, v in num_mm_slices_lst.items()
-        }
+        mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)

         return BatchFeature({
-            "input_ids": np.array([tokenizer.encode(prompt)]),
-            **{
-                key: value
-                for modality in inputs
-                for key, value in inputs[modality].items()
-            },
-            **{
-                f"{modality}_num_slices": num_mm_slices[modality]
-                for modality in mm_input_modalities
-            }
+            "input_ids":
+            torch.tensor([tokenizer.encode(prompt)]),
+            **mm_inputs,
         })

     def _hf_processor_applies_updates(
@@ -810,7 +708,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
         return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
-        supported_mm_modalities = self.info.get_supported_mm_modalities()
         if isinstance(prompt, list):
             prompt = self.info.get_tokenizer().decode(prompt)
         matches = re.findall(self.get_placeholder_match_pattern(), prompt)
@@ -818,7 +715,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             f"{modality}_orders":
            torch.tensor(
                [index for index, m in enumerate(matches) if m == modality])
-            for modality in supported_mm_modalities
+            for modality in self.info.get_supported_mm_limits()
         }
         result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                return_mm_hashes)
@@ -884,18 +781,19 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         self,
         input_ids: torch.Tensor,
         image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
-        else:
+
+        if image_inputs is None:
+            return vlm_embedding
+
         if image_inputs["type"] == "image_embeds":
-            vision_hidden_states = (image_inputs["data"].type(
-                vlm_embedding.dtype).to(vlm_embedding.device))
+            vision_hidden_states = image_inputs["image_embeds"].to(
+                device=vlm_embedding.device,
+                dtype=vlm_embedding.dtype,
+            )
         else:
-            vision_hidden_states = self.get_vision_hidden_states(
-                image_inputs)
+            vision_hidden_states = self.get_vision_hidden_states(image_inputs)

         # See NOTE in _parse_and_validate_inputs
         image_bounds = image_inputs["image_bounds"]
@@ -904,15 +802,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
                 torch.arange(start, end, dtype=torch.long)
                 for start, end in image_bounds.tolist()
             ]).to(vlm_embedding.device)

             vlm_embedding.scatter_(
                 0,
-                image_indices.view(-1, 1).repeat(1,
-                                                 vlm_embedding.shape[-1]),
-                vision_hidden_states.view(-1,
-                                          vision_hidden_states.shape[-1]),
+                image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
             )
-        return vlm_embedding, vision_hidden_states
+
+        return vlm_embedding

     def _get_image_bounds(
         self,
@@ -947,90 +844,115 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         input_ids: torch.Tensor,
         **kwargs: object,
     ) -> Optional[MiniCPMVImageInputs]:
-        mm_data = {
-            "image": {
-                key: kwargs.pop(key, [])
-                for key in ["pixel_values", "tgt_sizes", "image_num_slices"]
-            },
-            "video": {
-                "pixel_values": kwargs.pop("video_pixel_values", []),
-                "tgt_sizes": kwargs.pop("video_tgt_sizes", []),
-                "video_num_slices": kwargs.pop("video_num_slices", [])
-            }
-        }
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
-        mm_orders = {
-            f"{modality}": kwargs.pop(f"{modality}_orders", None)
-            for modality in ["image", "video", "audio"]
-        }
-        batch_size = max(len(mm_data["image"]["pixel_values"]),
-                         len(mm_data["video"]["pixel_values"]))
-        image_embeds = kwargs.pop("image_embeds", None)
-        video_embeds = kwargs.pop("video_embeds", None)
-        if image_embeds is not None and video_embeds is not None:
-            raise ValueError(
-                "Incorrect inputs for vision embeddings. "
-                "Image embeds and video embeds can not exist simultaneously.")
-        if video_embeds is not None:
-            image_embeds = video_embeds
-        if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(f"Incorrect type of image embeds. "
-                                 f"Got type: {type(image_embeds)}")
-            image_embeds = torch.concat(
-                [image_embeds[i] for i in range(len(image_embeds))])
-            return MiniCPMVImageEmbeddingInputs(
-                image_bounds=self._get_image_bounds(input_ids, im_start_id,
-                                                    im_end_id, slice_start_id,
-                                                    slice_end_id),
-                data=image_embeds,
-                type="image_embeds",
-            )
-
-        for modality, modality_mm_data in mm_data.items():
-            if not isinstance(modality_mm_data["pixel_values"],
-                              (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(modality_mm_data['pixel_values'])}")
-
-            if not isinstance(modality_mm_data["tgt_sizes"],
-                              (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of target sizes. "
-                    f"Got type: {type(modality_mm_data['tgt_sizes'])}")
-
-            if len(modality_mm_data["pixel_values"]) != len(
-                    modality_mm_data["tgt_sizes"]):
-                raise ValueError(
-                    "Inconsistent batch lengths, found: "
-                    f"{len(modality_mm_data['pixel_values'])} vs. "
-                    f"{len(modality_mm_data['tgt_sizes'])}")
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
-        for b in range(batch_size):
-            mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \
-                else {"image": 0}
-            mm_slice_counts = {"image": 0, "video": 0} \
-                if self.version == (2, 6) else {"image": 0}
-            mm_orders_b = [(index, modality) for modality in mm_counts
-                           for index in mm_orders[modality][b]]
-            for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
-                pos = mm_counts[modality]
-                num_slices = mm_data[modality][f"{modality}_num_slices"][b][
-                    pos]
-                slice_start_idx = mm_slice_counts[modality]
-                slice_end_idx = slice_start_idx + num_slices
-                pixel_values_flat += mm_data[modality]["pixel_values"][b][
-                    slice_start_idx:slice_end_idx]
-                tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][
-                    slice_start_idx:slice_end_idx]
-                mm_counts[modality] += 1
-                mm_slice_counts[modality] += num_slices
+        image_keys = {"pixel_values", "tgt_sizes"}
+        pixel_data = {
+            "image": {
+                key: kwargs.pop(key, None)
+                for key in image_keys
+            },
+            "video": {
+                key: kwargs.pop("video_" + key, None)
+                for key in image_keys
+            }
+        }
+        embed_data = {
+            "image": kwargs.pop("image_embeds", None),
+            "video": kwargs.pop("video_embeds", None),
+        }
+
+        all_pixel_data = [
+            v for vs in pixel_data.values() for v in vs.values()
+            if v is not None
+        ]
+        all_embed_data = [v for v in embed_data.values() if v is not None]
+        if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
+            return None
+
+        im_start_id = kwargs.pop("im_start_id")
+        if not isinstance(im_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of im_start_id. "
+                             f"Got type: {type(im_start_id)}")
+
+        im_end_id = kwargs.pop("im_end_id")
+        if not isinstance(im_end_id, torch.Tensor):
+            raise ValueError("Incorrect type of im_end_id. "
+                             f"Got type: {type(im_end_id)}")
+
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        if slice_start_id is not None and not isinstance(
+                slice_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of slice_start_id. "
+                             f"Got type: {type(slice_start_id)}")
+
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        if slice_end_id is not None and not isinstance(slice_end_id,
+                                                       torch.Tensor):
+            raise ValueError("Incorrect type of slice_end_id. "
+                             f"Got type: {type(slice_end_id)}")
+
+        if len(all_embed_data) > 0:
+            if len(all_embed_data) > 1:
+                raise ValueError("Incorrect inputs for vision embeddings. "
+                                 "Image embeds and video embeds can not "
+                                 "exist simultaneously.")
+
+            vision_embeds, = all_embed_data
+            if not isinstance(vision_embeds, (torch.Tensor, list)):
+                raise ValueError(f"Incorrect type of vision_embeds. "
+                                 f"Got type: {type(vision_embeds)}")
+
+            return MiniCPMVImageEmbeddingInputs(
+                type="image_embeds",
+                image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
+                                        concat=True),
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                    im_end_id, slice_start_id,
+                                                    slice_end_id),
+            )
+
+        order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
+        for modality in ("image", "video"):
+            modality_orders = kwargs.pop(f"{modality}_orders", None)
+            if modality_orders is not None:
+                if not isinstance(modality_orders, (torch.Tensor, list)):
+                    raise ValueError(f"Incorrect type of {modality}_orders. "
+                                     f"Got type: {type(modality_orders)}")
+
+                order_data[modality] = modality_orders
+
+        batch_sizes = {
+            modality: len(modality_orders)
+            for modality, modality_orders in order_data.items()
+        }
+        unique_batch_sizes = set(batch_sizes.values())
+        assert len(unique_batch_sizes) == 1, (
+            f"Found inconsistent batch sizes: {batch_sizes}")
+        batch_size, = unique_batch_sizes
+
+        pixel_values_flat = list[torch.Tensor]()
+        tgt_sizes_flat = list[torch.Tensor]()
+        for b in range(batch_size):
+            mm_orders_b = [(idx_b.item(), modality)
+                           for modality, modality_orders in order_data.items()
+                           for idx_b in modality_orders[b]]
+
+            for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
+                modality_pixel_data = pixel_data[modality]
+
+                modality_pixel_values = modality_pixel_data["pixel_values"]
+                if not isinstance(modality_pixel_values, (torch.Tensor, list)):
+                    raise ValueError(
+                        f"Incorrect type of pixel_values for {modality=}. "
+                        f"Got type: {type(modality_pixel_values)}")
+
+                modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
+                if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
+                    raise ValueError(
+                        f"Incorrect type of tgt_sizes for {modality=}. "
+                        f"Got type: {type(modality_tgt_sizes)}")
+
+                pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
+                tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])

         # NOTE: Input IDs does not contain image tokens during memory profiling,
         # so we allow it to be empty
@@ -1042,16 +964,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         if len(pixel_values_flat) == 0:
             return None

-        if im_start_id is None:
-            return None
-
         return MiniCPMVImagePixelInputs(
+            type="pixel_values",
+            pixel_values=pixel_values_flat,
+            tgt_sizes=torch.stack(tgt_sizes_flat),
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            data=pixel_values_flat,
-            tgt_sizes=torch.stack(tgt_sizes_flat),
-            type="pixel_values",
         )

     def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
@@ -1070,7 +989,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         else:
             image_inputs = \
                 self._parse_and_validate_inputs(input_ids, **kwargs)

-            vlm_embeddings, _ = self.get_embedding_with_vision(
+            vlm_embeddings = self.get_embedding_with_vision(
                 input_ids, image_inputs)

             # always pass the input via `inputs_embeds`
@@ -1136,16 +1055,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
                           prefix: str = "") -> nn.Module:
         raise NotImplementedError

-    def get_vision_embedding(
-        self,
-        pixel_values: List[torch.Tensor],
-        patch_attn_mask: Optional[torch.Tensor] = None,
-        tgt_sizes: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         raise NotImplementedError
@@ -1216,35 +1127,27 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
         return resampler.to(device=current_platform.device_type,
                             dtype=torch.get_default_dtype())

-    def get_vision_embedding(
-        self,
-        pixel_values: List[torch.Tensor],
-        patch_attn_mask: Optional[torch.Tensor] = None,
-        tgt_sizes: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        res = []
-        dtype = self.vpm.pos_embed.data.dtype
-        for pixel_value in pixel_values:
-            H, W = pixel_value[0].shape[-2:]
-            tgt_size = (
-                math.ceil(H / self.vpm.patch_embed.patch_size[0]),
-                math.ceil(W / self.vpm.patch_embed.patch_size[0]),
-            )
-            vision_embedding = self.vpm.forward_features(
-                pixel_value.unsqueeze(0).type(dtype))
-            if (hasattr(self.vpm, "num_prefix_tokens")
-                    and self.vpm.num_prefix_tokens > 0):
-                vision_embedding = vision_embedding[:, self.vpm.
-                                                    num_prefix_tokens:]
-            res.append(self.resampler(vision_embedding, tgt_size))
-        return torch.vstack(res)
-
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
-        pixel_values = data["data"]
-
-        return self.get_vision_embedding(pixel_values)
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+
+        P_h, P_w = self.vpm.patch_embed.patch_size
+        dtype: torch.dtype = self.vpm.pos_embed.data.dtype
+        num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)
+
+        res = list[torch.Tensor]()
+        for pixel_value in pixel_values:
+            H, W = pixel_value[0].shape[-2:]
+            tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
+
+            vision_embedding = self.vpm.forward_features(
+                pixel_value.unsqueeze(0).type(dtype))
+
+            if num_prefix_tokens > 0:
+                vision_embedding = vision_embedding[:, num_prefix_tokens:]
+
+            res.append(self.resampler(vision_embedding, tgt_size))
+
+        return torch.vstack(res)


 class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     packed_modules_mapping = {
@@ -1299,45 +1202,41 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         return resampler.to(device=current_platform.device_type,
                             dtype=torch.get_default_dtype())

-    def get_vision_embedding(
-        self,
-        pixel_values: List[torch.Tensor],
-        patch_attn_mask: Optional[torch.Tensor] = None,
-        tgt_sizes: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        vision_embedding = self.vpm(pixel_values,
-                                    patch_attention_mask=patch_attn_mask)
-        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
-        return vision_embedding
-
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
-        pixel_values = data["data"]
-        tgt_sizes = data["tgt_sizes"]
-
-        device = self.vpm.embeddings.position_embedding.weight.device
-        dtype = self.vpm.embeddings.position_embedding.weight.dtype
-        all_pixel_values_lst = [
-            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
-        ]
-
-        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
-        assert isinstance(max_patches, int)
-
-        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
-            all_pixel_values_lst, batch_first=True, padding_value=0.0)
-        B, L, _ = all_pixel_values.shape
-        all_pixel_values = all_pixel_values.permute(0, 2,
-                                                    1).reshape(B, 3, -1, L)
-
-        patch_attn_mask = torch.zeros((B, 1, max_patches),
-                                      dtype=torch.bool,
-                                      device=device)
-        for i in range(B):
-            patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
-
-        return self.get_vision_embedding(all_pixel_values.type(dtype),
-                                         patch_attn_mask, tgt_sizes)
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=None,
+        )
+
+        return self.resampler(vision_embedding, tgt_sizes)


 class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
@@ -1394,47 +1293,37 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         return resampler.to(device=current_platform.device_type,
                             dtype=torch.get_default_dtype())

-    def get_vision_embedding(
-        self,
-        pixel_values: List[torch.Tensor],
-        patch_attn_mask: Optional[torch.Tensor] = None,
-        tgt_sizes: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        vision_embedding = self.vpm(
-            pixel_values,
-            patch_attention_mask=patch_attn_mask,
-            tgt_sizes=tgt_sizes,
-        )
-        return vision_embedding
-
-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
-        pixel_values = data["data"]
-        tgt_sizes = data["tgt_sizes"]
-
-        device = self.vpm.embeddings.position_embedding.weight.device
-        dtype = self.vpm.embeddings.position_embedding.weight.dtype
-        all_pixel_values_lst = [
-            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
-        ]
-
-        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
-        assert isinstance(max_patches, int)
-
-        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
-            all_pixel_values_lst, batch_first=True, padding_value=0.0)
-        B, L, _ = all_pixel_values.shape
-        all_pixel_values = all_pixel_values.permute(0, 2,
-                                                    1).reshape(B, 3, -1, L)
-
-        patch_attn_mask = torch.zeros((B, 1, max_patches),
-                                      dtype=torch.bool,
-                                      device=device)
-        for i in range(B):
-            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
-
-        vision_embedding = self.vpm(
-            all_pixel_values.type(dtype),
-            patch_attention_mask=patch_attn_mask,
-            tgt_sizes=tgt_sizes,
-        )
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=tgt_sizes,
+        )
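
The rewritten get_vision_hidden_states methods build the patch attention mask directly from tgt_sizes instead of padding flattened pixel values. A small standalone sketch of that masking pattern (the patch counts are illustrative only):

    import torch

    # (rows, cols) of patches for each image slice
    tgt_sizes = torch.tensor([[4, 6], [8, 8]])
    num_patches = tgt_sizes.prod(-1)  # tensor([24, 64])
    max_patches = int(num_patches.max())

    patch_attn_mask = torch.zeros((len(tgt_sizes), max_patches),
                                  dtype=torch.bool)
    for i, n in enumerate(num_patches):
        patch_attn_mask[i, :n] = True  # valid patches; the rest is padding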

View File

@@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         return cast(BatchedTensorInputs, json_mapped)

+    def __delitem__(self, key: str) -> None:
+        super().__delitem__(key)
+
+        for items in self._items_by_modality.values():
+            for item in items:
+                item.pop(key, None)
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
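
The new __delitem__ keeps the per-modality item caches in sync with the top-level mapping. A standalone sketch of the same pattern using a made-up KwargsLike class (not the real MultiModalKwargs):

    from collections import UserDict

    class KwargsLike(UserDict):

        def __init__(self, data, items_by_modality):
            super().__init__(data)
            self._items_by_modality = items_by_modality

        def __delitem__(self, key):
            super().__delitem__(key)
            for items in self._items_by_modality.values():
                for item in items:
                    item.pop(key, None)

    kw = KwargsLike({"audio_features": 1, "audio_feature_lens": 2},
                    {"audio": [{"audio_features": 1}]})
    del kw["audio_features"]
    assert "audio_features" not in kw
    assert "audio_features" not in kw._items_by_modality["audio"][0]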