[Misc] Clean up MiniCPM-V/O code (#15337)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit a9e879b316 (parent 3e2f37a69a)
@@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
    engine_args = EngineArgs(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
@@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     #### Extended model tests
-    # "aria": VLMTestInfo(
-    #     models=["rhymes-ai/Aria"],
-    #     test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-    #     prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
-    #     img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
-    #     max_model_len=4096,
-    #     max_num_seqs=2,
-    #     auto_cls=AutoModelForImageTextToText,
-    #     single_image_prompts=IMAGE_ASSETS.prompts({
-    #         "stop_sign": "<vlm_image>Please describe the image shortly.",
-    #         "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
-    #     }),
-    #     multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
-    #     stop_str=["<|im_end|>"],
-    #     image_size_factors=[(0.10, 0.15)],
-    #     max_tokens=64,
-    #     marks=[large_gpu_mark(min_gb=64)],
-    # ),
+    "aria": VLMTestInfo(
+        models=["rhymes-ai/Aria"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
+        img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "<vlm_image>Please describe the image shortly.",
+            "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
+        }),
+        multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
+        stop_str=["<|im_end|>"],
+        image_size_factors=[(0.10, 0.15)],
+        max_tokens=64,
+        marks=[large_gpu_mark(min_gb=64)],
+    ),
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
@@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
        prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
        num_video_frames=16,
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
    ),
@@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = {
     ),
     "minicpmo_26": VLMTestInfo(
         models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        test_type=(VLMTestType.IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
     ),
+    "minicpmo_26_multi_image": VLMTestInfo(
+        models=["openbmb/MiniCPM-o-2_6"],
+        test_type=(VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        test_type=(VLMTestType.IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -404,6 +417,18 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
     ),
+    "minicpmv_26_multi_image": VLMTestInfo(
+        models=["openbmb/MiniCPM-V-2_6"],
+        test_type=(VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
+        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
+        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE),
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0

-import copy
 from functools import partial
 from typing import Optional, Union
@@ -29,7 +28,7 @@ def _test_processing_correctness(
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
         # For some multimodal models, tokenizer will always add bos_token
@@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_result,
         cached_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )

     baseline_tokenized_result = baseline_processor.apply(
         token_prompt,
@@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_result,
         baseline_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )

     cached_tokenized_result = cached_processor.apply(
         token_prompt,
@@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         cached_result,
         cached_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )


 def _test_processing_correctness_mistral(
@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
     baseline_processor: BaseMultiModalProcessor,
     cached_processor: BaseMultiModalProcessor,
     batch_idx: int,
-    ignore_mm_keys: Optional[list[str]] = None,
+    ignore_mm_keys: Optional[set[str]] = None,
 ):
     images = mm_data.get("image", [])
     if not isinstance(images, list):
@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
         hf_processor_mm_kwargs={},
     )

-    assert _inputs_equal(
+    _assert_inputs_equal(
         baseline_tokenized_result,
         cached_tokenized_result,
-        ignore_mm_keys,
-    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+        ignore_mm_keys=ignore_mm_keys,
+        msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
+    )


 # yapf: disable
@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
     "TIGER-Lab/Mantis-8B-siglip-llama3",
     "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
+    "openbmb/MiniCPM-Llama3-V-2_5",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
     "allenai/Molmo-7B-D-0924",
@@ -290,7 +294,7 @@ def test_processing_correctness(
     # In Ultravox, the audio_features can be different depending on padding
     # The slight difference should not be a problem though, since
     # attention_mask lets us ignore the difference.
-    ignore_mm_keys = ['audio_features']
+    ignore_mm_keys = {"audio_features"}

     _test_processing_correctness(
         model_id,
@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
     )


-def _inputs_equal(
+def _assert_inputs_equal(
     a: MultiModalInputs,
     b: MultiModalInputs,
-    ignore_mm_keys: Optional[list[str]] = None,
+    *,
+    ignore_mm_keys: Optional[set[str]] = None,
+    msg: str = "",
 ):
-    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
-        b, ignore_mm_keys)
-
-
-def _drop_mm_kwargs_keys(
-    result: MultiModalInputs,
-    ignore_mm_keys: Optional[list[str]] = None,
-) -> MultiModalInputs:
-    """Drop specified keys from result['mm_kwargs'].
-
-    This is mainly to avoid doing exact match of audio_features in ultravox.
-
-    Args:
-        result: Result to drop keys from
-        ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
-    """
-    if not ignore_mm_keys:
-        return result
-
-    if 'mm_kwargs' in result:
-        result = copy.deepcopy(result)
-        mm_kwargs = result['mm_kwargs']
-        for key in ignore_mm_keys:
-            mm_kwargs.pop(key, None)
-        for items in mm_kwargs._items_by_modality.values():
-            for item in items:
-                for key in ignore_mm_keys:
-                    item.pop(key, None)
-
-    return result
+    if ignore_mm_keys is None:
+        ignore_mm_keys = set()
+
+    if msg is None:
+        assert "mm_kwargs" in a and "mm_kwargs" in b
+    else:
+        assert "mm_kwargs" in a and "mm_kwargs" in b, msg
+
+    for key in ignore_mm_keys:
+        a["mm_kwargs"].pop(key, None)
+        b["mm_kwargs"].pop(key, None)
+
+    if msg is None:
+        assert a == b
+    else:
+        assert a == b, msg
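Note: a minimal standalone sketch of the pop-and-compare pattern that the new _assert_inputs_equal uses (here MultiModalInputs is stood in for by plain dicts; the key names are taken from the diff). Popping the ignored keys is idempotent, so repeated comparisons of the same result objects stay consistent:

# Sketch: drop ignored keys from both sides' "mm_kwargs", then compare
# exactly. This mirrors the helper above without the vLLM types.
a = {"prompt": "x", "mm_kwargs": {"audio_features": [1.0], "input_ids": [1, 2]}}
b = {"prompt": "x", "mm_kwargs": {"audio_features": [1.1], "input_ids": [1, 2]}}

for key in {"audio_features"}:
    a["mm_kwargs"].pop(key, None)
    b["mm_kwargs"].pop(key, None)

assert a == b, "inputs differ outside the ignored keys"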
@@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):

         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
-            assert isinstance(images, list)
-
             parsed_images = (self._get_data_parser().parse_mm_data({
                 "image":
                 images
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
+from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
@@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
                        MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
                        _minicpmv_field_config)
-from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
+from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
+                    maybe_prefix)

 CPU_DEVICE = torch.device("cpu")


 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    audio_features: torch.Tensor
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
     which is the same as image.
-    Padding is used therefore `data` is `torch.Tensor`
+    Padding is used therefore `audio_features` is `torch.Tensor`
     """

     audio_feature_lens: torch.Tensor
@@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     Shape: `(batch_size * num_audios * num_slices)`

     This should be feature length of each audio slice,
-    which equals to `data.shape[-1]`
+    which equals to `audio_features.shape[-1]`
     """

     audio_bounds: torch.Tensor
@@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):

 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: List[torch.Tensor]
+    audio_embeds: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_slices, hidden_size)`
@@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,


 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
-        audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_num_slices=MultiModalFieldConfig.batched("audio"),
-        audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-        audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
+        audio_features=MultiModalFieldConfig.batched("audio"),
+        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
     )
@@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
 class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     audio_pattern = "(<audio>./</audio>)"

-    def get_supported_mm_modalities(self) -> List[str]:
-        return ["image", "video", "audio"]
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None, "video": None, "audio": None}
@@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-        audios = mm_data.pop("audios", [])
-        audio_embeds = mm_data.pop("audio_embeds", [])
-        if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
-            audio_outputs = {
-                "audio_lens": [],
-                "audio_features": [],
-                "audio_feature_lens": [],
-                "audio_num_segments": []
-            }
-            for audio in audios:
-                single_audio_outputs = super().call_base_hf_processor(
-                    prompt=self.info.audio_pattern,
-                    mm_data={
-                        "audios": audio,
-                        "chunk_input": True
-                    },
-                    mm_kwargs=mm_kwargs)
-                audio_outputs["audio_lens"].append(len(audio))
-                audio_outputs["audio_features"].append(
-                    single_audio_outputs["audio_features"])
-                audio_outputs["audio_num_segments"].append(
-                    len(single_audio_outputs["audio_feature_lens"][0]))
-                audio_outputs["audio_feature_lens"] += \
-                    single_audio_outputs["audio_feature_lens"]
-            audio_outputs["audio_features"] = [
-                audio_feature for single_audio_features in \
-                    audio_outputs["audio_features"]
-                for audio_feature in single_audio_features
-            ]
-            audio_outputs["audio_feature_lens"] = torch.cat(
-                audio_outputs["audio_feature_lens"])
-        elif len(audio_embeds):
-            audio_outputs = {
-                "audio_lens": [
-                    self.info.get_audio_len_by_num_chunks(
-                        sum(chunk_embeds.shape[0]
-                            for chunk_embeds in single_audio_embeds))
-                    for single_audio_embeds in audio_embeds
-                ],
-                "audio_embeds": [
-                    chunk_embeds for single_audio_embeds in audio_embeds
-                    for chunk_embeds in single_audio_embeds
-                ],
-                "audio_num_segments": [
-                    len(single_audio_embeds)
-                    for single_audio_embeds in audio_embeds
-                ]
-            }
-        else:
-            audio_outputs = {}
-        return audio_outputs
+        if (audios := mm_data.get("audios")) is None:
+            return {}
+
+        parsed_audios = (self._get_data_parser().parse_mm_data({
+            "audio": audios
+        }).get_items("audio", AudioProcessorItems))
+
+        audio_inputs = self._base_call_hf_processor(
+            prompts=[self.info.audio_pattern] * len(parsed_audios),
+            mm_data={"audios": [[audio] for audio in parsed_audios]},
+            mm_kwargs={
+                **mm_kwargs, "chunk_input": True
+            },
+            out_keys={"audio_features", "audio_feature_lens"},
+        )
+
+        # Avoid padding since we need the output for each audio to be
+        # independent of other audios for the cache to work correctly
+        unpadded_audio_features = [
+            feat[:, :feature_len] for feat, feature_len in zip(
+                audio_inputs["audio_features"],
+                audio_inputs["audio_feature_lens"],
+            )
+        ]
+        audio_inputs["audio_features"] = unpadded_audio_features
+
+        return audio_inputs
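Note: the unpadding step above trims each padded feature tensor back to its true length so a given audio's output does not depend on what else was in the batch. A minimal sketch with made-up shapes (80 mel channels, padded time dimension 120):

import torch

# Two audio features padded to the same time dimension; lens holds the
# true (unpadded) length of each. Trimming makes each output independent
# of whatever else happened to be in the batch, as the cache requires.
feats = [torch.randn(80, 120), torch.randn(80, 120)]
lens = torch.tensor([100, 57])

unpadded = [feat[:, :int(n)] for feat, n in zip(feats, lens)]
assert [t.shape[-1] for t in unpadded] == [100, 57]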
     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video|audio)>./</\1>\)"

-    def get_placeholder_split_pattern(self) -> str:
-        return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
-
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, Mapping[str, NestedTensors]]:
+    ) -> Mapping[str, NestedTensors]:
         return {
-            "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs),
-            "audio": self.process_audios(mm_data, mm_kwargs),
+            **super().process_mm_inputs(mm_data, mm_kwargs),
+            **self.process_audios(mm_data, mm_kwargs),
         }

-    def get_modality_num_counter(self, modality: str) -> str:
-        if modality == "audio":
-            return "audio_lens"
-        return super().get_modality_num_counter(modality)
-
-    def get_num_slices_by_modality(self, inputs: Dict[str, object],
-                                   modality: str, index: int) -> int:
-        if modality == "audio":
-            return inputs["audio"]["audio_num_segments"][index]
-        return super().get_num_slices_by_modality(inputs, modality, index)
-
-    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
-                                     modality: str, index: int) -> str:
-        if modality == "audio":
-            return self.get_audio_prompt_texts(
-                inputs["audio"]["audio_lens"][index])
-        return super().get_prompt_texts_by_modality(inputs, modality, index)
-
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6):
     # Copied from HF repo of MiniCPM-o-2_6,
     # designed for batched inputs and outputs
     def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
-                                chunk_length: int) -> torch.Tensor:
+                                chunk_length: int) -> list[torch.Tensor]:
         wavforms = data.get(
-            "data",
+            "audio_features",
             [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = [data.get("audio_feature_lens",
                                            [])]  # list, [[x1, x2], [y1], [z1]]

-        # exist audio
-        if len(wavforms) > 0:
-            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
-            batch_size, _, max_mel_seq_len = wavforms.shape
-            max_seq_len = (max_mel_seq_len - 1) // 2 + 1
-
-            # Create a sequence tensor of shape (batch_size, max_seq_len)
-            seq_range = (torch.arange(
-                0,
-                max_seq_len,
-                dtype=audio_feature_lens.dtype,
-                device=audio_feature_lens.device).unsqueeze(0).expand(
-                    batch_size, max_seq_len))
-            lengths_expand = audio_feature_lens.unsqueeze(1).expand(
-                batch_size, max_seq_len)
-            # Create mask
-            padding_mask = seq_range >= lengths_expand  # 1 for padded values
-
-            audio_attention_mask_ = padding_mask.view(
-                batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
-                                                      max_seq_len, max_seq_len)
-            audio_attention_mask = audio_attention_mask_.to(
-                dtype=self.apm.conv1.weight.dtype,
-                device=self.apm.conv1.weight.device)
-
-            if chunk_length > 0:
-                chunk_num_frame = int(chunk_length * 50)
-                chunk_mask = self.subsequent_chunk_mask(
-                    size=max_seq_len,
-                    chunk_size=chunk_num_frame,
-                    num_left_chunks=-1,
-                    device=audio_attention_mask_.device,
-                )
-                audio_attention_mask_ = torch.logical_or(
-                    audio_attention_mask_, torch.logical_not(chunk_mask))
-
-            audio_attention_mask[audio_attention_mask_] = float("-inf")
-            audio_states = self.apm(
-                wavforms, attention_mask=audio_attention_mask).hidden_states[
-                    self.audio_encoder_layer]
-            audio_embeds = self.audio_projection_layer(audio_states)
-
-            audio_embeds = audio_embeds.transpose(1, 2)
-            audio_embeds = self.audio_avg_pooler(audio_embeds)
-            audio_embeds = audio_embeds.transpose(1, 2)
-
-            _, feature_lens_after_pooling = \
-                self._get_feat_extract_output_lengths(audio_feature_lens)
-
-            num_audio_tokens = feature_lens_after_pooling
-
-            final_audio_embeds = []
-            idx = 0
-            for i in range(len(audio_feature_lens_raw)):
-                target_audio_embeds = []
-                for _ in range(len(audio_feature_lens_raw[i])):
-                    target_audio_embeds.append(
-                        audio_embeds[idx, :num_audio_tokens[idx], :])
-                    idx += 1
-                final_audio_embeds.append(target_audio_embeds)
-            return final_audio_embeds
-        else:
-            return []
+        if len(wavforms) == 0:
+            return []
+
+        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+        batch_size, _, max_mel_seq_len = wavforms.shape
+        max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+        # Create a sequence tensor of shape (batch_size, max_seq_len)
+        seq_range = (torch.arange(
+            0,
+            max_seq_len,
+            dtype=audio_feature_lens.dtype,
+            device=audio_feature_lens.device).unsqueeze(0).expand(
+                batch_size, max_seq_len))
+        lengths_expand = audio_feature_lens.unsqueeze(1).expand(
+            batch_size, max_seq_len)
+        # Create mask
+        padding_mask = seq_range >= lengths_expand  # 1 for padded values
+
+        audio_attention_mask_ = padding_mask.view(
+            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
+                                                  max_seq_len)
+        audio_attention_mask = audio_attention_mask_.to(
+            dtype=self.apm.conv1.weight.dtype,
+            device=self.apm.conv1.weight.device)
+
+        if chunk_length > 0:
+            chunk_num_frame = int(chunk_length * 50)
+            chunk_mask = self.subsequent_chunk_mask(
+                size=max_seq_len,
+                chunk_size=chunk_num_frame,
+                num_left_chunks=-1,
+                device=audio_attention_mask_.device,
+            )
+            audio_attention_mask_ = torch.logical_or(
+                audio_attention_mask_, torch.logical_not(chunk_mask))
+
+        audio_attention_mask[audio_attention_mask_] = float("-inf")
+        audio_states = self.apm(
+            wavforms, attention_mask=audio_attention_mask).hidden_states[
+                self.audio_encoder_layer]
+        audio_embeds = self.audio_projection_layer(audio_states)
+
+        audio_embeds = audio_embeds.transpose(1, 2)
+        audio_embeds = self.audio_avg_pooler(audio_embeds)
+        audio_embeds = audio_embeds.transpose(1, 2)
+
+        _, feature_lens_after_pooling = \
+            self._get_feat_extract_output_lengths(audio_feature_lens)
+
+        num_audio_tokens = feature_lens_after_pooling
+
+        final_audio_embeds = []
+        idx = 0
+        for i in range(len(audio_feature_lens_raw)):
+            target_audio_embeds = []
+            for _ in range(len(audio_feature_lens_raw[i])):
+                target_audio_embeds.append(
+                    audio_embeds[idx, :num_audio_tokens[idx], :])
+                idx += 1
+            final_audio_embeds.append(target_audio_embeds)
+        return final_audio_embeds

     def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
-                                  audio_inputs: Optional[MiniCPMOAudioInputs],
+                                  audio_inputs: MiniCPMOAudioInputs,
                                   chunk_length: int) -> torch.Tensor:
         device, dtype = vlm_embedding.device, vlm_embedding.dtype
         if audio_inputs["type"] == "audio_embeds":
-            audio_embeddings = audio_inputs["data"]
             audio_embeddings = [
-                audio_embeddings[i].to(device=device, dtype=dtype)
-                for i in range(len(audio_embeddings))
+                item.to(device=device, dtype=dtype)
+                for item in audio_inputs["audio_embeds"]
             ]
         else:
             audio_embeddings = self.get_audio_hidden_states(
@@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):

     def _parse_and_validate_audio_inputs(
             self, input_ids: torch.Tensor,
-            **kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
-        audio_features = kwargs.pop("audio_features", [])
-        audio_feature_lens = kwargs.pop("audio_feature_lens", [])
+            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
+        audio_features = kwargs.pop("audio_features", None)
         audio_embeds = kwargs.pop("audio_embeds", None)
-        audio_start_id = kwargs.pop("audio_start_id", None)
-        audio_end_id = kwargs.pop("audio_end_id", None)
+
+        if audio_features is None and audio_embeds is None:
+            return None
+
+        audio_start_id = kwargs.pop("audio_start_id")
+        if not isinstance(audio_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_start_id. "
+                             f"Got type: {type(audio_start_id)}")
+
+        audio_end_id = kwargs.pop("audio_end_id")
+        if not isinstance(audio_end_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_end_id. "
+                             f"Got type: {type(audio_end_id)}")
+
         if audio_embeds is not None:
-            audio_embeds = [
-                audio_embeds[i][j] for i in range(len(audio_embeds))
-                for j in range(len(audio_embeds[i]))
-            ]
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+
             return MiniCPMOAudioEmbeddingInputs(
+                type="audio_embeds",
+                audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
+                                        concat=True),
                 audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                     audio_end_id),
-                data=audio_embeds,
-                type="audio_embeds")
-        if len(audio_features) > 0:
-            audio_features_all = [
-                i.permute(1, 0) for audio_feature in audio_features
-                for i in audio_feature
-            ]
-            audio_features = torch.nn.utils.rnn.pad_sequence(
-                audio_features_all, batch_first=True,
-                padding_value=0.0).permute(0, 2, 1)
-            audio_feature_lens = torch.cat(
-                [item for item in audio_feature_lens])
+            )
+
+        if audio_features is not None:
+            if not isinstance(audio_features, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_features. "
+                                 f"Got type: {type(audio_features)}")
+
+            audio_feature_lens = kwargs.pop("audio_feature_lens")
+            if not isinstance(audio_feature_lens, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_feature_lens. "
+                                 f"Got type: {type(audio_feature_lens)}")

             return MiniCPMOAudioFeatureInputs(
+                type="audio_features",
+                audio_features=flatten_bn(audio_features, concat=True),
+                audio_feature_lens=flatten_bn(
+                    flatten_2d_lists(audio_feature_lens), concat=True),
                 audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                     audio_end_id),
-                data=audio_features,
-                audio_feature_lens=audio_feature_lens,
-                type="audio_features")
-        return None
+            )
+
+        raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                    **kwargs: object):
@@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
         else:
             image_inputs, audio_inputs = \
                 self._parse_and_validate_inputs(input_ids, **kwargs)
-            vlm_embeddings, _ = self.get_embedding_with_vision(
+            vlm_embeddings = self.get_embedding_with_vision(
                 input_ids, image_inputs)

             if audio_inputs is not None:
@@ -24,6 +24,7 @@
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
 import math
 import re
+from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property, partial
 from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
@@ -63,11 +64,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsV0Only)
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix

 CPU_DEVICE = torch.device("cpu")
@@ -76,7 +78,7 @@ RawImageType = Union[Image.Image, torch.Tensor]

 class MiniCPMVImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: List[torch.Tensor]
+    pixel_values: list[torch.Tensor]
     """
     Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
@@ -101,7 +103,7 @@ class MiniCPMVImagePixelInputs(TypedDict):

 class MiniCPMVImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: torch.Tensor
+    image_embeds: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_slices,
             image_feature_size, hidden_size)`
@@ -231,26 +233,15 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:


 def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
-    video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
-
     return dict(
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
+        pixel_values=MultiModalFieldConfig.batched("image"),
         image_sizes=MultiModalFieldConfig.batched("image"),
-        tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-        image_num_slices=MultiModalFieldConfig.batched("image"),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_num_slices),
-        video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
+        tgt_sizes=MultiModalFieldConfig.batched("image"),
+        image_embeds=MultiModalFieldConfig.batched("image"),
+        video_pixel_values=MultiModalFieldConfig.batched("video"),
         video_image_sizes=MultiModalFieldConfig.batched("video"),
-        video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_num_slices),
-        video_num_slices=MultiModalFieldConfig.batched("video"),
+        video_tgt_sizes=MultiModalFieldConfig.batched("video"),
+        video_embeds=MultiModalFieldConfig.batched("video"),
     )
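Note: for context on the change above, a hedged sketch of the two field layouts (the authoritative semantics live in vllm.multimodal; shapes and counts here are made up). `MultiModalFieldConfig.batched(modality)` treats element 0 of the data as indexing items of that modality, while `flat_from_sizes(modality, sizes)` recovers per-item data by slicing a flat dim-0 concatenation with per-item slice counts:

import torch

# Hypothetical two images with 3 and 1 slices, hidden size 5.
num_slices = torch.tensor([3, 1])
flat = torch.randn(4, 5)  # flat_from_sizes layout: 3 + 1 slices stacked

offsets = [0, *num_slices.cumsum(0).tolist()]
per_item_flat = [flat[offsets[i]:offsets[i + 1]] for i in range(2)]

batched = [torch.randn(3, 5), torch.randn(1, 5)]  # batched layout: one entry per item

assert [t.shape for t in per_item_flat] == [t.shape for t in batched]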
@@ -356,12 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
     def get_model_version(self):
         return get_version_by_config(self.get_hf_config())

-    def get_supported_mm_modalities(self) -> List[str]:
-        if self.get_model_version() == (2, 6):
-            return ["image", "video"]
-        else:
-            return ["image"]
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         if self.get_model_version() == (2, 6):
             return {"image": None, "video": None}
@@ -526,187 +511,123 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def get_image_prompt_texts(self,
                                image_size: ImageSize,
                                image_idx: int = 0) -> str:
-        prompt_texts = self.get_slice_image_placeholder(image_size,
-                                                        image_idx=image_idx)
-        return prompt_texts
+        return self.get_slice_image_placeholder(image_size,
+                                                image_idx=image_idx)

     def get_video_prompt_texts(self, image_size: ImageSize,
                                num_frames: int) -> str:
-        prompt_texts = "".join(
-            self.get_slice_image_placeholder(
-                image_size=image_size,
-                image_idx=0,
-                max_slice_nums=self.info.get_video_max_slice_num(),
-                use_image_id=False) for image_idx in range(num_frames))
-        return prompt_texts
+        return self.get_slice_image_placeholder(
+            image_size=image_size,
+            image_idx=0,
+            max_slice_nums=self.info.get_video_max_slice_num(),
+            use_image_id=False,
+        ) * num_frames

     def get_special_tokens(self) -> Dict[str, torch.Tensor]:
         tokenizer = self.info.get_tokenizer()

         special_tokens = {
-            "im_start_id": torch.tensor(tokenizer.im_start_id),
-            "im_end_id": torch.tensor(tokenizer.im_end_id)
+            "im_start_id": tokenizer.im_start_id,
+            "im_end_id": tokenizer.im_end_id,
         }
         if hasattr(tokenizer, "slice_start_id"):
-            special_tokens["slice_start_id"] = torch.tensor(
-                tokenizer.slice_start_id)
-            special_tokens["slice_end_id"] = torch.tensor(
-                tokenizer.slice_end_id)
-        return special_tokens
-
-    @staticmethod
-    def repack_processor_outputs(outputs: Any) -> BatchFeature:
-        valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"]
-        outputs = {key: outputs[key][0] for key in valid_keys}
-        return outputs
+            special_tokens["slice_start_id"] = tokenizer.slice_start_id
+            special_tokens["slice_end_id"] = tokenizer.slice_end_id
+
+        return {k: torch.tensor(v) for k, v in special_tokens.items()}

     def process_images(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-        images = mm_data.pop("images", [])
-        image_embeds = mm_data.pop("image_embeds", [])
-        if isinstance(images, Image.Image):
-            images = [images]
-        if isinstance(images, (list, torch.Tensor)) and len(images) > 0:
-            image_outputs = super()._call_hf_processor(
-                prompt=self.info.image_pattern * len(images),
-                mm_data={"images": images},
-                mm_kwargs=mm_kwargs)
-            image_outputs = self.repack_processor_outputs(image_outputs)
-        elif len(image_embeds) > 0:
-            image_sizes = mm_data.pop("image_sizes", None)
-            image_outputs = {
-                "image_embeds": torch.cat(image_embeds),
-                "image_sizes": image_sizes
-            }
-        else:
-            image_outputs = {}
-        return image_outputs
+        if (images := mm_data.get("images")) is None:
+            return {}
+
+        parsed_images = (self._get_data_parser().parse_mm_data({
+            "image": images
+        }).get_items("image", ImageProcessorItems))
+
+        return self._base_call_hf_processor(
+            prompts=[self.info.image_pattern] * len(parsed_images),
+            mm_data={"images": [[image] for image in parsed_images]},
+            mm_kwargs=mm_kwargs,
+            out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
+        )

     def process_videos(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-        videos = mm_data.pop("videos", [])
-        video_embeds = mm_data.pop("video_embeds", [])
-        if len(videos) > 0 and isinstance(videos[0], Image.Image):
-            videos = [videos]
-        if isinstance(videos, list) and len(videos) > 0:
-            video_outputs = {
-                "video_pixel_values": [],
-                "video_image_sizes": [],
-                "video_tgt_sizes": [],
-                "num_frames": []
-            }
-            for video in videos:
-                parsed_video = []
-                for frame in video:
-                    if isinstance(frame, np.ndarray):
-                        parsed_video.append(Image.fromarray(frame))
-                    else:
-                        parsed_video.append(frame)
-                video = parsed_video
-                single_video_outputs = super()._call_hf_processor(
-                    prompt=self.info.image_pattern * len(video),
-                    mm_data={"images": video},
-                    mm_kwargs={
-                        **mm_kwargs, "max_slice_nums":
-                        self.info.get_video_max_slice_num()
-                    })
-                video_outputs["num_frames"].append(len(video))
-                for key in single_video_outputs:
-                    if "video_" + key in video_outputs:
-                        if key == "image_sizes":
-                            video_outputs["video_" + key].append(
-                                single_video_outputs[key][0][0])
-                        else:
-                            video_outputs["video_" +
-                                          key] += single_video_outputs[key][0]
-        elif len(video_embeds):
-            image_sizes = mm_data.pop("image_sizes", None)
-            num_frames = mm_data.pop("num_frames", None)
-            video_outputs = {
-                "video_embeds": torch.cat(video_embeds),
-                "video_image_sizes": image_sizes,
-                "num_frames": num_frames
-            }
-        else:
-            video_outputs = {}
-        return video_outputs
+        if (videos := mm_data.get("videos")) is None:
+            return {}
+
+        parsed_videos = (self._get_data_parser().parse_mm_data({
+            "video": videos
+        }).get_items("video", VideoProcessorItems))
+
+        max_slice_num = self.info.get_video_max_slice_num()
+
+        video_inputs = self._base_call_hf_processor(
+            prompts=[
+                self.info.image_pattern * len(video) for video in parsed_videos
+            ],
+            mm_data={"images": list(parsed_videos)},
+            mm_kwargs={
+                **mm_kwargs, "max_slice_nums": max_slice_num
+            },
+            out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
+        )
+
+        return {f"video_{k}": v for k, v in video_inputs.items()}

     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video)>./</\1>\)"

-    def get_placeholder_split_pattern(self) -> str:
-        return r"\(<(?:image|video)>./</(?:image|video)>\)"
-
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, Mapping[str, NestedTensors]]:
+    ) -> Mapping[str, NestedTensors]:
         return {
-            "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs),
+            **self.process_images(mm_data, mm_kwargs),
+            **self.process_videos(mm_data, mm_kwargs),
         }

-    def get_input_modalities(self, mm_data) -> List[str]:
-        supported_mm_modalities = self.info.get_supported_mm_modalities()
-        input_modalities = []
-        for modality in supported_mm_modalities:
-            if modality in mm_data and mm_data[modality] != {}:
-                input_modalities.append(modality)
-        return input_modalities
-
-    def get_modality_num_counter(self, modality: str) -> str:
-        if modality == "image":
-            return "image_sizes"
-        elif modality == "video":
-            return "video_image_sizes"
-
-        raise NotImplementedError(modality)
-
-    def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str,
-                                   index: int) -> int:
-        if modality == "image":
-            return self.info.get_image_slice_nums(
-                inputs[modality]["image_sizes"][index],
-                self.info.get_max_slice_num())
-        elif modality == "video":
-            return self.info.get_image_slice_nums(
-                inputs[modality]["video_image_sizes"][index],
-                self.info.get_video_max_slice_num()
-            ) * inputs[modality]["num_frames"][index]
-        else:
-            raise ValueError(f"Unexpected modality: {modality}")
-
-    def get_prompt_texts_by_modality(self, inputs: dict[str, Any],
-                                     modality: str, index: int) -> str:
-        if modality == "image":
-            return self.get_image_prompt_texts(
-                inputs["image"]["image_sizes"][index], index)
-        elif modality == "video":
-            return self.get_video_prompt_texts(
-                inputs["video"]["video_image_sizes"][index],
-                inputs["video"]["num_frames"][index])
-        else:
-            raise ValueError(f"Unexpected modality: {modality}")
-
-    def call_base_hf_processor(
+    def _base_call_hf_processor(
         self,
-        prompt: str,
-        mm_data: Mapping[str, object],
+        prompts: list[str],
+        mm_data: Mapping[str, Sequence[object]],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        return super()._call_hf_processor(prompt=prompt,
-                                          mm_data=mm_data,
-                                          mm_kwargs=mm_kwargs)
+        *,
+        out_keys: set[str],
+    ) -> Mapping[str, NestedTensors]:
+        # This processor supports zipping prompt and mm_data together
+        if self.info.get_model_version() == (2, 6):
+            inputs = super()._call_hf_processor(
+                prompt=prompts,  # type: ignore
+                mm_data=mm_data,
+                mm_kwargs=mm_kwargs,
+            )
+        else:
+            inputs = defaultdict[str, list[torch.Tensor]](list)
+
+            for i, prompt in enumerate(prompts):
+                inputs_one = super()._call_hf_processor(
+                    prompt=prompt,
+                    mm_data={
+                        k: v[i]
+                        for k, v in mm_data.items()
+                    },
+                    mm_kwargs=mm_kwargs,
+                )
+
+                for k, v in inputs_one.items():
+                    assert len(v) == 1, (k, len(v))
+                    inputs[k].append(v[0])
+
+        return {k: inputs[k] for k in out_keys}
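Note: `_base_call_hf_processor` above either hands the HF processor a list of prompts zipped with per-item data (the 2.6 path) or falls back to one call per item and collects the single-element outputs. A minimal sketch of that fallback shape, with a hypothetical `process_one` standing in for the HF processor call:

from collections import defaultdict

def call_per_item(prompts, mm_data, process_one):
    # process_one(prompt, data) -> dict[str, list]; each value is a
    # single-element list, mirroring an HF processor called on one item.
    outputs = defaultdict(list)
    for i, prompt in enumerate(prompts):
        one = process_one(prompt, {k: v[i] for k, v in mm_data.items()})
        for k, v in one.items():
            assert len(v) == 1, (k, len(v))
            outputs[k].append(v[0])
    return dict(outputs)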
     def _call_hf_processor(
         self,
@@ -717,35 +638,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         # Do not support combination inputs of images and videos for now
         # Try to handle interleaved multimodal data
         tokenizer = self.info.get_tokenizer()
-        inputs = self.process_mm_inputs(mm_data, mm_kwargs)
-        mm_input_modalities = self.get_input_modalities(inputs)
-
-        num_mm_slices_lst = {
-            modality: list[int]()
-            for modality in mm_input_modalities
-        }
-        for modality in mm_input_modalities:
-            num_counter_key = self.get_modality_num_counter(modality)
-            for index in range(len(inputs[modality][num_counter_key])):
-                num_mm_slices_lst[modality].append(
-                    self.get_num_slices_by_modality(inputs, modality, index))
-
-        num_mm_slices = {
-            modality: torch.tensor(v)
-            for modality, v in num_mm_slices_lst.items()
-        }
+        mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)

         return BatchFeature({
-            "input_ids": np.array([tokenizer.encode(prompt)]),
-            **{
-                key: value
-                for modality in inputs
-                for key, value in inputs[modality].items()
-            },
-            **{
-                f"{modality}_num_slices": num_mm_slices[modality]
-                for modality in mm_input_modalities
-            }
+            "input_ids":
+            torch.tensor([tokenizer.encode(prompt)]),
+            **mm_inputs,
         })

     def _hf_processor_applies_updates(
@@ -810,7 +708,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
         return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
-        supported_mm_modalities = self.info.get_supported_mm_modalities()
         if isinstance(prompt, list):
             prompt = self.info.get_tokenizer().decode(prompt)
         matches = re.findall(self.get_placeholder_match_pattern(), prompt)
@@ -818,7 +715,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
             f"{modality}_orders":
             torch.tensor(
                 [index for index, m in enumerate(matches) if m == modality])
-            for modality in supported_mm_modalities
+            for modality in self.info.get_supported_mm_limits()
         }
         result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                return_mm_hashes)
@@ -884,35 +781,35 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         self,
         input_ids: torch.Tensor,
         image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
+        if image_inputs is None:
+            return vlm_embedding
+
+        if image_inputs["type"] == "image_embeds":
+            vision_hidden_states = image_inputs["image_embeds"].to(
+                device=vlm_embedding.device,
+                dtype=vlm_embedding.dtype,
+            )
         else:
-            if image_inputs["type"] == "image_embeds":
-                vision_hidden_states = (image_inputs["data"].type(
-                    vlm_embedding.dtype).to(vlm_embedding.device))
-            else:
-                vision_hidden_states = self.get_vision_hidden_states(
-                    image_inputs)
+            vision_hidden_states = self.get_vision_hidden_states(image_inputs)

-            # See NOTE in _parse_and_validate_inputs
-            image_bounds = image_inputs["image_bounds"]
-            if len(image_bounds) > 0:
-                image_indices = torch.stack([
-                    torch.arange(start, end, dtype=torch.long)
-                    for start, end in image_bounds.tolist()
-                ]).to(vlm_embedding.device)
-                vlm_embedding.scatter_(
-                    0,
-                    image_indices.view(-1, 1).repeat(1,
-                                                     vlm_embedding.shape[-1]),
-                    vision_hidden_states.view(-1,
-                                              vision_hidden_states.shape[-1]),
-                )
+        # See NOTE in _parse_and_validate_inputs
+        image_bounds = image_inputs["image_bounds"]
+        if len(image_bounds) > 0:
+            image_indices = torch.stack([
+                torch.arange(start, end, dtype=torch.long)
+                for start, end in image_bounds.tolist()
+            ]).to(vlm_embedding.device)

-        return vlm_embedding, vision_hidden_states
+            vlm_embedding.scatter_(
+                0,
+                image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+            )
+
+        return vlm_embedding

     def _get_image_bounds(
         self,
@@ -947,90 +844,115 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         input_ids: torch.Tensor,
         **kwargs: object,
     ) -> Optional[MiniCPMVImageInputs]:
-        mm_data = {
+        image_keys = {"pixel_values", "tgt_sizes"}
+        pixel_data = {
             "image": {
-                key: kwargs.pop(key, [])
-                for key in ["pixel_values", "tgt_sizes", "image_num_slices"]
+                key: kwargs.pop(key, None)
+                for key in image_keys
             },
             "video": {
-                "pixel_values": kwargs.pop("video_pixel_values", []),
-                "tgt_sizes": kwargs.pop("video_tgt_sizes", []),
-                "video_num_slices": kwargs.pop("video_num_slices", [])
+                key: kwargs.pop("video_" + key, None)
+                for key in image_keys
             }
         }
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
-        mm_orders = {
-            f"{modality}": kwargs.pop(f"{modality}_orders", None)
-            for modality in ["image", "video", "audio"]
+        embed_data = {
+            "image": kwargs.pop("image_embeds", None),
+            "video": kwargs.pop("video_embeds", None),
         }
-        batch_size = max(len(mm_data["image"]["pixel_values"]),
-                         len(mm_data["video"]["pixel_values"]))
-        image_embeds = kwargs.pop("image_embeds", None)
-        video_embeds = kwargs.pop("video_embeds", None)
-        if image_embeds is not None and video_embeds is not None:
-            raise ValueError(
-                "Incorrect inputs for vision embeddings. "
-                "Image embeds and video embeds can not exist simultaneously.")
-        if video_embeds is not None:
-            image_embeds = video_embeds
-        if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(f"Incorrect type of image embeds. "
-                                 f"Got type: {type(image_embeds)}")
-            image_embeds = torch.concat(
-                [image_embeds[i] for i in range(len(image_embeds))])
+
+        all_pixel_data = [
+            v for vs in pixel_data.values() for v in vs.values()
+            if v is not None
+        ]
+        all_embed_data = [v for v in embed_data.values() if v is not None]
+        if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
+            return None
+
+        im_start_id = kwargs.pop("im_start_id")
+        if not isinstance(im_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of im_start_id. "
+                             f"Got type: {type(im_start_id)}")
+
+        im_end_id = kwargs.pop("im_end_id")
+        if not isinstance(im_end_id, torch.Tensor):
+            raise ValueError("Incorrect type of im_end_id. "
+                             f"Got type: {type(im_end_id)}")
+
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        if slice_start_id is not None and not isinstance(
+                slice_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of slice_start_id. "
+                             f"Got type: {type(slice_start_id)}")
+
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        if slice_end_id is not None and not isinstance(slice_end_id,
+                                                       torch.Tensor):
+            raise ValueError("Incorrect type of slice_end_id. "
+                             f"Got type: {type(slice_end_id)}")
+
+        if len(all_embed_data) > 0:
+            if len(all_embed_data) > 1:
+                raise ValueError("Incorrect inputs for vision embeddings. "
+                                 "Image embeds and video embeds can not "
+                                 "exist simultaneously.")
+
+            vision_embeds, = all_embed_data
+            if not isinstance(vision_embeds, (torch.Tensor, list)):
+                raise ValueError(f"Incorrect type of vision_embeds. "
+                                 f"Got type: {type(vision_embeds)}")

             return MiniCPMVImageEmbeddingInputs(
+                type="image_embeds",
+                image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
+                                        concat=True),
                 image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                     im_end_id, slice_start_id,
                                                     slice_end_id),
-                data=image_embeds,
-                type="image_embeds",
             )
-        for modality, modality_mm_data in mm_data.items():
-            if not isinstance(modality_mm_data["pixel_values"],
-                              (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(modality_mm_data['pixel_values'])}")
-
-            if not isinstance(modality_mm_data["tgt_sizes"],
-                              (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of target sizes. "
-                    f"Got type: {type(modality_mm_data['tgt_sizes'])}")
-
-            if len(modality_mm_data["pixel_values"]) != len(
-                    modality_mm_data["tgt_sizes"]):
-                raise ValueError(
-                    "Inconsistent batch lengths, found: "
-                    f"{len(modality_mm_data['pixel_values'])} vs. "
-                    f"{len(modality_mm_data['tgt_sizes'])}")
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
+
+        order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
+        for modality in ("image", "video"):
+            modality_orders = kwargs.pop(f"{modality}_orders", None)
+            if modality_orders is not None:
+                if not isinstance(modality_orders, (torch.Tensor, list)):
+                    raise ValueError(f"Incorrect type of {modality}_orders. "
+                                     f"Got type: {type(modality_orders)}")
+
+                order_data[modality] = modality_orders
+
+        batch_sizes = {
+            modality: len(modality_orders)
+            for modality, modality_orders in order_data.items()
+        }
+        unique_batch_sizes = set(batch_sizes.values())
+        assert len(unique_batch_sizes) == 1, (
+            f"Found inconsistent batch sizes: {batch_sizes}")
+        batch_size, = unique_batch_sizes
+
+        pixel_values_flat = list[torch.Tensor]()
+        tgt_sizes_flat = list[torch.Tensor]()
         for b in range(batch_size):
-            mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \
-                else {"image": 0}
-            mm_slice_counts = {"image": 0, "video": 0} \
-                if self.version == (2, 6) else {"image": 0}
-            mm_orders_b = [(index, modality) for modality in mm_counts
-                           for index in mm_orders[modality][b]]
+            mm_orders_b = [(idx_b.item(), modality)
+                           for modality, modality_orders in order_data.items()
+                           for idx_b in modality_orders[b]]
+
             for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
-                pos = mm_counts[modality]
-                num_slices = mm_data[modality][f"{modality}_num_slices"][b][
-                    pos]
-                slice_start_idx = mm_slice_counts[modality]
-                slice_end_idx = slice_start_idx + num_slices
-                pixel_values_flat += mm_data[modality]["pixel_values"][b][
-                    slice_start_idx:slice_end_idx]
-                tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][
-                    slice_start_idx:slice_end_idx]
-                mm_counts[modality] += 1
-                mm_slice_counts[modality] += num_slices
+                modality_pixel_data = pixel_data[modality]
+
+                modality_pixel_values = modality_pixel_data["pixel_values"]
+                if not isinstance(modality_pixel_values, (torch.Tensor, list)):
+                    raise ValueError(
+                        f"Incorrect type of pixel_values for {modality=}. "
+                        f"Got type: {type(modality_pixel_values)}")
+
+                modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
+                if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
+                    raise ValueError(
+                        f"Incorrect type of tgt_sizes for {modality=}. "
+                        f"Got type: {type(modality_tgt_sizes)}")
+
+                pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
+                tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
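Note: `flatten_2d_lists` (from vllm.utils) flattens one level of nesting, which is what turns the per-item `[num_slices, ...]` structure into the flat per-slice lists built above. A sketch of the helper's behavior as used here:

def flatten_2d_lists_sketch(lst):
    # One level of flattening: [[a, b], [c]] -> [a, b, c].
    return [item for sub in lst for item in sub]

assert flatten_2d_lists_sketch([[1, 2], [3]]) == [1, 2, 3]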
         # NOTE: Input IDs does not contain image tokens during memory profiling,
         # so we allow it to be empty
@@ -1042,16 +964,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         if len(pixel_values_flat) == 0:
             return None

-        if im_start_id is None:
-            return None
-
         return MiniCPMVImagePixelInputs(
+            type="pixel_values",
+            pixel_values=pixel_values_flat,
+            tgt_sizes=torch.stack(tgt_sizes_flat),
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            data=pixel_values_flat,
-            tgt_sizes=torch.stack(tgt_sizes_flat),
-            type="pixel_values",
         )

     def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
@@ -1070,7 +989,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
         else:
             image_inputs = \
                 self._parse_and_validate_inputs(input_ids, **kwargs)
-            vlm_embeddings, _ = self.get_embedding_with_vision(
+            vlm_embeddings = self.get_embedding_with_vision(
                 input_ids, image_inputs)

         # always pass the input via `inputs_embeds`
@ -1136,16 +1055,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
                       prefix: str = "") -> nn.Module:
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        raise NotImplementedError
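After this hunk the base class exposes a single abstract get_vision_hidden_states hook; the separate get_vision_embedding hook is gone. A toy sketch of the narrowed interface (hypothetical classes, for illustration only):

import torch
from torch import nn

class VisionTowerBase(nn.Module):
    """Hypothetical stand-in: one hook, overridden per model version."""

    def get_vision_hidden_states(self, data: dict) -> torch.Tensor:
        raise NotImplementedError

class ToyVisionTower(VisionTowerBase):

    def get_vision_hidden_states(self, data: dict) -> torch.Tensor:
        # Toy implementation: pool each slice to a single feature vector.
        return torch.stack([
            pv.float().mean(dim=(-2, -1)).flatten()
            for pv in data["pixel_values"]
        ])

tower = ToyVisionTower()
hidden = tower.get_vision_hidden_states(
    {"pixel_values": [torch.randn(3, 14, 14), torch.randn(3, 28, 28)]})
print(hidden.shape)  # torch.Size([2, 3])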
@ -1216,35 +1127,27 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        res = []
        dtype = self.vpm.pos_embed.data.dtype
    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]

        P_h, P_w = self.vpm.patch_embed.patch_size
        dtype: torch.dtype = self.vpm.pos_embed.data.dtype
        num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)

        res = list[torch.Tensor]()
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            tgt_size = (
                math.ceil(H / self.vpm.patch_embed.patch_size[0]),
                math.ceil(W / self.vpm.patch_embed.patch_size[0]),
            )
            tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))
            if (hasattr(self.vpm, "num_prefix_tokens")
                    and self.vpm.num_prefix_tokens > 0):
                vision_embedding = vision_embedding[:, self.vpm.
                                                    num_prefix_tokens:]

            if num_prefix_tokens > 0:
                vision_embedding = vision_embedding[:, num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))

        return torch.vstack(res)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["data"]

        return self.get_vision_embedding(pixel_values)
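The 2.0 rewrite hoists the prefix-token handling out of the loop via getattr with a default. A tiny demonstration of the equivalence (FakeTimmViT is a made-up stand-in for the timm backbone):

import torch

class FakeTimmViT:
    num_prefix_tokens = 1  # e.g. a CLS token prepended by the backbone

vpm = FakeTimmViT()

# getattr with a default subsumes the old hasattr-plus-attribute check.
num_prefix_tokens = getattr(vpm, "num_prefix_tokens", 0)

embedding = torch.randn(1, 1 + 16, 8)  # (batch, prefix + patches, dim)
if num_prefix_tokens > 0:
    embedding = embedding[:, num_prefix_tokens:]
print(embedding.shape)  # torch.Size([1, 16, 8])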
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
    packed_modules_mapping = {
@ -1299,45 +1202,41 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(pixel_values,
                                    patch_attention_mask=patch_attn_mask)
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["data"]
    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]
        B = len(pixel_values)
        P = pixel_values[0].shape[-2]
        L = max(item.shape[-1] for item in pixel_values)
        device = pixel_values[0].device
        dtype = pixel_values[0].dtype

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        all_pixel_values = torch.zeros((B, 3, P, L),
                                       dtype=dtype,
                                       device=device)
        for i, pixel_values_item in enumerate(pixel_values):
            L_item = pixel_values_item.shape[-1]
            all_pixel_values[i, ..., :L_item] = pixel_values_item

        num_patches = tgt_sizes.prod(-1)
        max_patches = num_patches.max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
        patch_attn_mask = torch.zeros((B, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
        for i, num_patches_item in enumerate(num_patches):
            patch_attn_mask[i, :num_patches_item] = True

        return self.get_vision_embedding(all_pixel_values.type(dtype),
                                         patch_attn_mask, tgt_sizes)
        vision_embedding = self.vpm(
            all_pixel_values,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=None,
        )

        return self.resampler(vision_embedding, tgt_sizes)
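The 2.5 path above (and the 2.6 path below) now pads its variable-length slice tensors with a preallocated zero buffer sized by the longest slice, and builds the patch mask from tgt_sizes.prod(-1) rather than per-row index arithmetic. A minimal sketch of that padding-and-masking step with toy shapes (the real tensors are flattened patch sequences, so dimensions here are illustrative):

import torch

# Toy "slices": each is (num_patches_i, feature_dim) after flattening.
tgt_sizes = torch.tensor([[2, 3], [1, 2]])         # (h_patches, w_patches)
slices = [torch.randn(6, 12), torch.randn(2, 12)]  # 6 = 2*3, 2 = 1*2 patches

# Pad to the longest slice in the batch.
padded = torch.nn.utils.rnn.pad_sequence(slices,
                                         batch_first=True,
                                         padding_value=0.0)  # (2, 6, 12)

# Mark which positions are real patches and which are padding.
num_patches = tgt_sizes.prod(-1)                   # tensor([6, 2])
max_patches = int(num_patches.max().item())
patch_attn_mask = torch.zeros((len(slices), max_patches), dtype=torch.bool)
for i, n in enumerate(num_patches.tolist()):
    patch_attn_mask[i, :n] = True

print(padded.shape, patch_attn_mask.sum(-1))  # torch.Size([2, 6, 12]) tensor([6, 2])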
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
@ -1394,47 +1293,37 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
        return resampler.to(device=current_platform.device_type,
                            dtype=torch.get_default_dtype())

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        )
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["data"]
    def get_vision_hidden_states(
            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]
        B = len(pixel_values)
        P = pixel_values[0].shape[-2]
        L = max(item.shape[-1] for item in pixel_values)
        device = pixel_values[0].device
        dtype = pixel_values[0].dtype

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        all_pixel_values = torch.zeros((B, 3, P, L),
                                       dtype=dtype,
                                       device=device)
        for i, pixel_values_item in enumerate(pixel_values):
            L_item = pixel_values_item.shape[-1]
            all_pixel_values[i, ..., :L_item] = pixel_values_item

        num_patches = tgt_sizes.prod(-1)
        max_patches = num_patches.max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
        patch_attn_mask = torch.zeros((B, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
        for i, num_patches_item in enumerate(num_patches):
            patch_attn_mask[i, :num_patches_item] = True

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            all_pixel_values,
            patch_attention_mask=patch_attn_mask.unsqueeze(1),
            tgt_sizes=tgt_sizes,
        )
@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):

        return cast(BatchedTensorInputs, json_mapped)

    def __delitem__(self, key: str) -> None:
        super().__delitem__(key)

        for items in self._items_by_modality.values():
            for item in items:
                item.pop(key, None)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False