[VLM] merged multimodal processor and V1 support for idefics3 (#12660)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
18a88fcccc
commit
815079de8e
@ -733,7 +733,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `HuggingFaceM4/Idefics3-8B-Llama3` etc.
|
||||
* ✅︎
|
||||
*
|
||||
*
|
||||
* ✅︎
|
||||
- * `InternVLChatModel`
|
||||
* InternVL 2.5, Mono-InternVL, InternVL 2.0
|
||||
* T + I<sup>E+</sup>
|
||||
|
@ -254,14 +254,14 @@ VLM_TEST_SETTINGS = {
|
||||
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
|
||||
),
|
||||
"idefics3": VLMTestInfo(
|
||||
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
|
||||
models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<image>",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
marks=[large_gpu_mark(min_gb=48)],
|
||||
hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
|
||||
),
|
||||
"intern_vl": VLMTestInfo(
|
||||
models=[
|
||||
|
@ -192,6 +192,14 @@ def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
|
||||
return output_ids, output_str, out_logprobs
|
||||
|
||||
|
||||
def idefics3_trunc_hf_output(hf_output: RunnerOutput,
|
||||
model: str) -> RunnerOutput:
|
||||
output_ids, output_str, out_logprobs = hf_output
|
||||
if output_str.endswith("<end_of_utterance>"):
|
||||
output_str = output_str.split("<end_of_utterance>")[0]
|
||||
return output_ids, output_str, out_logprobs
|
||||
|
||||
|
||||
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
|
||||
model: str) -> RunnerOutput:
|
||||
output_ids, output_str, out_logprobs = hf_output
|
||||
|
@ -149,6 +149,7 @@ def _test_processing_correctness(
|
||||
"adept/fuyu-8b",
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
"llava-hf/llava-1.5-7b-hf",
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
|
@ -1,13 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for Idefics3's multimodal preprocessing kwargs."""
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoImageProcessor, AutoTokenizer
|
||||
from transformers import Idefics3Config
|
||||
|
||||
from vllm.inputs import InputContext, token_inputs
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from ....conftest import _ImageAssets
|
||||
from ...utils import build_model_context
|
||||
@ -15,163 +12,53 @@ from ...utils import build_model_context
|
||||
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
|
||||
|
||||
|
||||
# Wrap lazy imports to avoid initializing CUDA during test collection
|
||||
@pytest.fixture()
|
||||
def input_processor_for_idefics3():
|
||||
from vllm.model_executor.models.idefics3 import (
|
||||
input_processor_for_idefics3)
|
||||
return input_processor_for_idefics3
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_data_for_idefics3():
|
||||
from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3
|
||||
return dummy_data_for_idefics3
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def get_max_idefics3_image_tokens():
|
||||
from vllm.model_executor.models.idefics3 import (
|
||||
get_max_idefics3_image_tokens)
|
||||
return get_max_idefics3_image_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336])
|
||||
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
|
||||
longest_edge: Optional[int]):
|
||||
"""Ensure that the [default] input mapper handles size properly."""
|
||||
|
||||
mm_processor_kwargs = {
|
||||
"size": {
|
||||
"longest_edge": longest_edge
|
||||
}
|
||||
} if longest_edge is not None else {}
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
hf_processor = AutoImageProcessor.from_pretrained(model,
|
||||
trust_remote_code=True,
|
||||
**mm_processor_kwargs)
|
||||
|
||||
mm_registry = MultiModalRegistry()
|
||||
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
|
||||
|
||||
image = image_assets[0].pil_image
|
||||
hf_result = hf_processor.preprocess(
|
||||
image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
vllm_result = mm_registry.map_input(
|
||||
ctx.model_config,
|
||||
{"image": image},
|
||||
)
|
||||
|
||||
assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("longest_edge, expected_max_tokens", [
|
||||
(None, 2873),
|
||||
(168, 169),
|
||||
(336, 169),
|
||||
(400, 338),
|
||||
(672, 338),
|
||||
])
|
||||
def test_max_tokens_override(get_max_idefics3_image_tokens, model: str,
|
||||
longest_edge: Optional[int],
|
||||
expected_max_tokens: int):
|
||||
"""Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs."""
|
||||
size = {"longest_edge": longest_edge} if longest_edge is not None else None
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
|
||||
actual_max_tokens = get_max_idefics3_image_tokens(
|
||||
ctx=InputContext(ctx.model_config),
|
||||
size=size,
|
||||
)
|
||||
|
||||
assert expected_max_tokens == actual_max_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [
|
||||
(168, 169, 1),
|
||||
(168, 169, 2),
|
||||
(400, 338, 1),
|
||||
(400, 338, 2),
|
||||
])
|
||||
def test_dummy_data_override(dummy_data_for_idefics3, model: str,
|
||||
longest_edge: int, toks_per_img: int,
|
||||
num_imgs: int):
|
||||
"""Ensure dummy_data_for_idefics3 handles num_crops properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the dummy data func.
|
||||
size = {"longest_edge": longest_edge} if longest_edge is not None else None
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
|
||||
dummy_data = dummy_data_for_idefics3(
|
||||
ctx=ctx,
|
||||
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
|
||||
mm_counts={"image": num_imgs},
|
||||
size=size)
|
||||
sequence_data = dummy_data.seq_data
|
||||
# Ensure we have the right number of placeholders per size
|
||||
image_token_id = ctx.get_hf_config().image_token_id
|
||||
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
|
||||
assert img_tok_count == toks_per_img * num_imgs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [
|
||||
(336, 169 * (1**2 + 1), 1),
|
||||
(336, 169 * (1**2 + 1), 2),
|
||||
(400, 169 * (2**2 + 1), 1),
|
||||
(400, 169 * (2**2 + 1), 2),
|
||||
])
|
||||
def test_input_processor_override(input_processor_for_idefics3,
|
||||
image_assets: _ImageAssets, model: str,
|
||||
longest_edge: int,
|
||||
expected_toks_per_img: int, num_imgs: int):
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("mm_processor_kwargs", "expected_toks_per_img"),
|
||||
[
|
||||
({"size": {"longest_edge": 364}}, 169),
|
||||
({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)),
|
||||
])
|
||||
# yapf: enable
|
||||
@pytest.mark.parametrize("num_imgs", [1, 2])
|
||||
def test_processor_override(image_assets: _ImageAssets, model: str,
|
||||
mm_processor_kwargs: dict[str, object],
|
||||
expected_toks_per_img: int, num_imgs: int):
|
||||
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the custom input processor.
|
||||
size = {"longest_edge": longest_edge} if longest_edge is not None else None
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
)
|
||||
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
|
||||
|
||||
# Build the image str / prompt based on the number of images we pass
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
placeholders = "<image>" if num_imgs == 1 else "\n".join(
|
||||
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
|
||||
prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501
|
||||
images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs
|
||||
|
||||
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
|
||||
prompt=prompt,
|
||||
multi_modal_data={"image": images})
|
||||
# Build mm_data
|
||||
image_size = ctx.get_hf_config(Idefics3Config).vision_config.image_size
|
||||
dummy_image_size = (image_size * 4, image_size * 4)
|
||||
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
|
||||
mm_data = {"image": [dummy_image] * num_imgs}
|
||||
|
||||
processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size)
|
||||
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
# Ensure the placeholders format are correct
|
||||
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
|
||||
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
|
||||
"input_ids"][0]
|
||||
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = ctx.get_hf_config().image_token_id
|
||||
|
@ -31,6 +31,17 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
|
||||
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
|
||||
|
||||
|
||||
class HashableDict(dict):
|
||||
"""
|
||||
A dictionary that can be hashed by lru_cache.
|
||||
"""
|
||||
|
||||
# NOTE: pythonic dict is not hashable,
|
||||
# we override on it directly for simplicity
|
||||
def __hash__(self) -> int: # type: ignore[override]
|
||||
return hash(frozenset(self.items()))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputContext:
|
||||
"""
|
||||
@ -104,6 +115,13 @@ class InputContext:
|
||||
if isinstance(typ, type):
|
||||
merged_kwargs["processor_cls"] = typ
|
||||
|
||||
# NOTE: Pythonic dict is not hashable and will raise unhashable type
|
||||
# error when calling `cached_get_processor`, therefore we need to
|
||||
# wrap it to a hashable dict.
|
||||
for key, value in merged_kwargs.items():
|
||||
if isinstance(value, dict):
|
||||
merged_kwargs[key] = HashableDict(value)
|
||||
|
||||
hf_processor = cached_get_processor(
|
||||
self.model_config.model,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
|
@ -16,35 +16,35 @@
|
||||
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
|
||||
Optional, Set, Tuple, TypedDict, Union)
|
||||
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
# Temporary solution for transformers below 4.46.0.
|
||||
from transformers import PretrainedConfig as Idefics3Config
|
||||
from transformers import ProcessorMixin as Idefics3ImageProcessor
|
||||
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
|
||||
Idefics3Processor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalDataItems,
|
||||
MultiModalFieldConfig,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
# yapf: disable
|
||||
from .idefics2_vision_model import (
|
||||
@ -77,307 +77,253 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class Idefics3ProcessorSize(NamedTuple):
|
||||
"""Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
|
||||
# NOTE: cached_get_processor/cached_get_image_processor uses lru_cache,
|
||||
# we need to use NamedTuple instead of TypedDict to avoid hashing issues.
|
||||
longest_edge: int
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._asdict() and getattr(self, key) is not None
|
||||
|
||||
def __getitem__(self, key: str) -> int:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
|
||||
|
||||
|
||||
def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
|
||||
mm_processor_kwargs = {}
|
||||
if size:
|
||||
mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
|
||||
return mm_processor_kwargs
|
||||
class Idefics3ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
size: Optional[Dict[str, int]] = None) -> Idefics3Processor:
|
||||
if size is not None:
|
||||
return self.ctx.get_hf_processor(Idefics3Processor, size=size)
|
||||
|
||||
def input_mapper_for_idefics3(
|
||||
ctx: InputContext,
|
||||
data: object,
|
||||
*,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
model_config = ctx.model_config
|
||||
mm_processor_kwargs = get_mm_processor_kwargs(size)
|
||||
image_processor = cached_get_image_processor(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
**mm_processor_kwargs)
|
||||
if image_processor is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
"to process the image object")
|
||||
return self.ctx.get_hf_processor(Idefics3Processor)
|
||||
|
||||
if isinstance(data, Image.Image):
|
||||
images = [[data]]
|
||||
elif is_list_of(data, Image.Image):
|
||||
images = [data]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(data)}")
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
try:
|
||||
batch_data = image_processor(images,
|
||||
return_tensors="pt",
|
||||
return_row_col_info=True).data
|
||||
except Exception:
|
||||
logger.error("Failed to process image (%s)", data)
|
||||
raise
|
||||
|
||||
return MultiModalKwargs(batch_data)
|
||||
|
||||
|
||||
def _resize_output_size(height: int,
|
||||
width: int,
|
||||
max_len: Optional[int] = None,
|
||||
min_len: Optional[int] = 1,
|
||||
max_size: Optional[int] = None) -> Tuple[int, int]:
|
||||
# Set default value for max_len if not provided
|
||||
max_len = max(height, width) if max_len is None else max_len
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Handle the maximum size constraint
|
||||
if max_size is not None:
|
||||
max_len = min(max_len, max_size)
|
||||
|
||||
# Adjust dimensions according to the aspect ratio
|
||||
if width >= height:
|
||||
width = max_len
|
||||
height = int(width / aspect_ratio)
|
||||
else:
|
||||
height = max_len
|
||||
width = int(height * aspect_ratio)
|
||||
|
||||
# Ensure both width and height are even (if needed)
|
||||
height += 1 if height % 2 != 0 else 0
|
||||
width += 1 if width % 2 != 0 else 0
|
||||
|
||||
# Ensure dimensions are not smaller than the minimum length
|
||||
height = max(height, min_len)
|
||||
width = max(width, min_len)
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
def _get_resize_output_image_size(
|
||||
image_size: Tuple[int, int],
|
||||
resolution_max_side: int,
|
||||
max_image_size: int = 1820,
|
||||
) -> Tuple[int, int]:
|
||||
if resolution_max_side > max_image_size:
|
||||
raise ValueError(
|
||||
"`resolution_max_side` cannot be larger than `max_image_size`")
|
||||
|
||||
height, width = image_size
|
||||
|
||||
# Find the output size, when rescaling the longest edge to max_len and
|
||||
# preserving the aspect ratio
|
||||
height, width = _resize_output_size(height,
|
||||
width,
|
||||
max_len=resolution_max_side)
|
||||
|
||||
return height, width
|
||||
|
||||
|
||||
def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int,
|
||||
fake_token_around_image: str, image_token: str,
|
||||
global_img_token: str) -> str:
|
||||
"""
|
||||
Prompt with expanded image tokens for when the image is split
|
||||
into patches.
|
||||
"""
|
||||
text_split_images = ""
|
||||
for n_h in range(image_rows):
|
||||
for n_w in range(image_cols):
|
||||
text_split_images += (fake_token_around_image +
|
||||
f"<row_{n_h + 1}_col_{n_w + 1}>" +
|
||||
image_token * image_seq_len)
|
||||
text_split_images += "\n"
|
||||
|
||||
text_split_images += "\n" + _prompt_single_image(
|
||||
image_seq_len=image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token)
|
||||
return text_split_images
|
||||
|
||||
|
||||
def _prompt_single_image(image_seq_len: int, fake_token_around_image: str,
|
||||
image_token: str, global_img_token: str):
|
||||
"""Prompt with expanded image tokens for a single image."""
|
||||
return (fake_token_around_image + global_img_token +
|
||||
image_token * image_seq_len + fake_token_around_image)
|
||||
|
||||
|
||||
def _get_image_prompt_string(image_rows: int, image_cols: int,
|
||||
image_seq_len: int, fake_token_around_image: str,
|
||||
image_token: str, global_img_token: str):
|
||||
if image_rows == 0 and image_cols == 0:
|
||||
return _prompt_single_image(
|
||||
image_seq_len=image_seq_len,
|
||||
fake_token_around_image=fake_token_around_image,
|
||||
image_token=image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
return _prompt_split_image(image_seq_len, image_rows, image_cols,
|
||||
fake_token_around_image, image_token,
|
||||
global_img_token)
|
||||
|
||||
|
||||
def input_processor_for_idefics3(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
size: Optional[Dict[str, int]] = None):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
mm_processor_kwargs = get_mm_processor_kwargs(size)
|
||||
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
|
||||
image_processor = processor.image_processor
|
||||
tokenizer = processor.tokenizer
|
||||
size = image_processor.size['longest_edge']
|
||||
max_image_size = image_processor.max_image_size['longest_edge']
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
image_list = [image_data]
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_list = image_data
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
image_rows = []
|
||||
image_cols = []
|
||||
for image in image_list:
|
||||
height, width = _get_resize_output_image_size(image.size, size)
|
||||
|
||||
rows = math.ceil(height / max_image_size)
|
||||
cols = math.ceil(width / max_image_size)
|
||||
image_rows.append(rows)
|
||||
image_cols.append(cols)
|
||||
image_rows = [image_rows]
|
||||
image_cols = [image_cols]
|
||||
|
||||
n_images_in_text = []
|
||||
|
||||
text = inputs.get("prompt")
|
||||
if text is None:
|
||||
prompt_token_ids = inputs.get("prompt_token_ids", [])
|
||||
assert prompt_token_ids
|
||||
text = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, "
|
||||
"or a list of strings")
|
||||
|
||||
fake_image_token = processor.fake_image_token.content
|
||||
image_token = processor.image_token.content
|
||||
global_img_token = processor.global_image_tag
|
||||
|
||||
prompt_strings = []
|
||||
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
||||
n_images_in_text.append(sample.count(image_token))
|
||||
|
||||
# Replace the image token with fake tokens around the expanded
|
||||
# image token sequence of length `image_seq_len`
|
||||
image_prompt_strings = []
|
||||
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
||||
image_prompt_string = _get_image_prompt_string(
|
||||
n_rows,
|
||||
n_cols,
|
||||
processor.image_seq_len,
|
||||
image_token=image_token,
|
||||
fake_token_around_image=fake_image_token,
|
||||
global_img_token=global_img_token,
|
||||
)
|
||||
image_prompt_strings.append(image_prompt_string)
|
||||
|
||||
split_sample = sample.split(image_token)
|
||||
if len(split_sample) == 0:
|
||||
raise ValueError("The image token should be present in the text.")
|
||||
|
||||
# Place in the image prompt strings where the image tokens are
|
||||
sample = split_sample[0]
|
||||
for i, image_prompt_string in enumerate(image_prompt_strings):
|
||||
sample += image_prompt_string + split_sample[i + 1]
|
||||
prompt_strings.append(sample)
|
||||
|
||||
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt_strings[0],
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
|
||||
size = image_processor.size['longest_edge']
|
||||
max_image_size = image_processor.max_image_size['longest_edge']
|
||||
resized_height, resized_width = size, size
|
||||
|
||||
grid_h = resized_height // max_image_size
|
||||
grid_w = resized_width // max_image_size
|
||||
return (grid_h * grid_w + 1)
|
||||
|
||||
|
||||
def get_max_idefics3_image_tokens(ctx: InputContext,
|
||||
*,
|
||||
size: Optional[Dict[str,
|
||||
int]] = None) -> int:
|
||||
model_config = ctx.model_config
|
||||
mm_processor_kwargs = get_mm_processor_kwargs(size)
|
||||
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
|
||||
image_seq_len = processor.image_seq_len
|
||||
image_processor = processor.image_processor
|
||||
|
||||
max_num_image_patches = _get_max_num_image_patch(image_processor)
|
||||
|
||||
return max_num_image_patches * image_seq_len
|
||||
|
||||
|
||||
def dummy_data_for_idefics3(
|
||||
ctx: InputContext,
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
hf_processor = self.get_hf_processor()
|
||||
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
||||
grid_w, grid_h = self._get_image_feature_grid_size(
|
||||
image_width=image_processor.size['longest_edge'],
|
||||
image_height=image_processor.size['longest_edge'],
|
||||
)
|
||||
num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
|
||||
# Calculate Non-image-token length
|
||||
# NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
|
||||
# but not for Idefic3, so we need to tokenize them to get actual length.
|
||||
tokenizer = self.get_tokenizer()
|
||||
tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
|
||||
glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
|
||||
# linebreak and <fake_token_around_image> always cost 1 token
|
||||
fake_token_len = lb_len = 1
|
||||
non_image_token = (grid_w * grid_h) * (
|
||||
tile_token_len + fake_token_len) + glob_token_len + (
|
||||
grid_h + 1) * lb_len + fake_token_len
|
||||
return {"image": num_image_token + non_image_token}
|
||||
|
||||
def _resize_output_size(self,
|
||||
*,
|
||||
height: int,
|
||||
width: int,
|
||||
max_len: Optional[int] = None,
|
||||
min_len: Optional[int] = 1,
|
||||
max_size: Optional[int] = None) -> tuple[int, int]:
|
||||
# Set default value for max_len if not provided
|
||||
max_len = max(height, width) if max_len is None else max_len
|
||||
aspect_ratio = width / height
|
||||
|
||||
# Handle the maximum size constraint
|
||||
if max_size is not None:
|
||||
max_len = min(max_len, max_size)
|
||||
|
||||
# Adjust dimensions according to the aspect ratio
|
||||
if width >= height:
|
||||
width = max_len
|
||||
height = int(width / aspect_ratio)
|
||||
else:
|
||||
height = max_len
|
||||
width = int(height * aspect_ratio)
|
||||
|
||||
# Ensure both width and height are even (if needed)
|
||||
height += height % 2
|
||||
width += width % 2
|
||||
|
||||
# Ensure dimensions are not smaller than the minimum length
|
||||
height = max(height, min_len)
|
||||
width = max(width, min_len)
|
||||
|
||||
return height, width
|
||||
|
||||
def _get_resize_output_image_size(
|
||||
self,
|
||||
*,
|
||||
size: Optional[Dict[str, int]] = None) -> DummyData:
|
||||
hf_config = ctx.get_hf_config()
|
||||
num_images = mm_counts["image"]
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
resolution_max_side: int,
|
||||
) -> tuple[int, int]:
|
||||
hf_processor = self.get_hf_processor()
|
||||
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
||||
max_image_size = image_processor.size['longest_edge']
|
||||
if resolution_max_side > max_image_size:
|
||||
raise ValueError(
|
||||
"`resolution_max_side` cannot be larger than `max_image_size`")
|
||||
|
||||
mm_processor_kwargs = get_mm_processor_kwargs(size)
|
||||
processor = cached_get_processor(ctx.model_config.model,
|
||||
**mm_processor_kwargs)
|
||||
max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
|
||||
image_seq_len = processor.image_seq_len
|
||||
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
|
||||
height, width = image_height, image_width
|
||||
|
||||
if seq_len - max_llm_image_tokens < 0:
|
||||
raise RuntimeError(
|
||||
f"Idefics3 cannot process {num_images} images in a prompt, "
|
||||
"please increase max_model_len or reduce image limit by "
|
||||
"--limit-mm-per-prompt.")
|
||||
# Find the output size, when rescaling the longest edge to max_len and
|
||||
# preserving the aspect ratio
|
||||
height, width = self._resize_output_size(height=height,
|
||||
width=width,
|
||||
max_len=resolution_max_side)
|
||||
return height, width
|
||||
|
||||
seq_data = SequenceData.from_prompt_token_counts(
|
||||
(hf_config.image_token_id, max_llm_image_tokens),
|
||||
(0, seq_len - max_llm_image_tokens))
|
||||
def _get_image_feature_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
size: Optional[dict[str, object]] = None,
|
||||
) -> tuple[int, int]:
|
||||
hf_processor = self.get_hf_processor(size=size)
|
||||
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
||||
max_image_size = image_processor.max_image_size['longest_edge']
|
||||
size = image_processor.size['longest_edge']
|
||||
assert size % max_image_size == 0, (
|
||||
"`longest_edge` in image_processor's `size` must be divisible by "
|
||||
"`longest_edge` in `max_image_size`, this may be caused by "
|
||||
"incorrect mm_kwargs override.")
|
||||
|
||||
width = height = hf_config.vision_config.image_size
|
||||
image = Image.new("RGB", (width, height), color=0)
|
||||
mm_data = {"image": [image] if num_images == 1 else [image] * num_images}
|
||||
resized_height, resized_width = self._get_resize_output_image_size(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
resolution_max_side=size,
|
||||
)
|
||||
if resized_height > max_image_size or resized_width > max_image_size:
|
||||
grid_h = math.ceil(resized_height / max_image_size)
|
||||
grid_w = math.ceil(resized_width / max_image_size)
|
||||
else:
|
||||
grid_h = grid_w = 0
|
||||
return grid_w, grid_h
|
||||
|
||||
return DummyData(seq_data, mm_data)
|
||||
|
||||
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
|
||||
):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
|
||||
longest_edge = image_processor.max_image_size['longest_edge']
|
||||
image_token: str = hf_processor.image_token.content
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=longest_edge,
|
||||
height=longest_edge,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=image_token * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class Idefics3MultimodalProcessor(
|
||||
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if mm_data:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt, mm_data, mm_kwargs)
|
||||
image_grids = [
|
||||
self.info._get_image_feature_grid_size(
|
||||
image_width=img.width,
|
||||
image_height=img.height,
|
||||
**mm_kwargs,
|
||||
) for img in mm_data["images"]
|
||||
]
|
||||
image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
|
||||
for key in ("pixel_values", "pixel_attention_mask"):
|
||||
data = processed_outputs.pop(key)
|
||||
data = data.flatten(0, 1).split(image_patches)
|
||||
processed_outputs[key] = data
|
||||
else:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
processed_outputs = tokenizer(prompt,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt")
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
pixel_attention_mask=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
image_token = hf_processor.image_token.content
|
||||
fake_image_token = hf_processor.fake_image_token.content
|
||||
global_img_token = hf_processor.global_image_tag
|
||||
image_seq_len = hf_processor.image_seq_len
|
||||
grid_placeholder = "<row_{n_h}_col_{n_w}>"
|
||||
|
||||
p_img = image_token * image_seq_len
|
||||
global_img_placeholder = fake_image_token + global_img_token + p_img
|
||||
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
|
||||
|
||||
def get_replacement_idefics3(item_idx: int) -> str:
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
|
||||
image_size = images.get_image_size(item_idx)
|
||||
grid_w, grid_h = self.info._get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
**hf_processor_mm_kwargs,
|
||||
)
|
||||
if grid_w == 0 and grid_h == 0:
|
||||
image_placeholder = global_img_placeholder
|
||||
else:
|
||||
tiles_placeholder = list[str]()
|
||||
for i in range(grid_h):
|
||||
for j in range(grid_w):
|
||||
placeholder_per_tile = tile_img_placeholder.format(
|
||||
n_h=i + 1, n_w=j + 1)
|
||||
tiles_placeholder.append(placeholder_per_tile)
|
||||
# Add line break if it is the last tile in the row
|
||||
if j == grid_w - 1:
|
||||
tiles_placeholder.append("\n")
|
||||
|
||||
image_placeholder = "".join(
|
||||
[*tiles_placeholder, "\n", global_img_placeholder])
|
||||
return image_placeholder + fake_image_token
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=image_token,
|
||||
replacement=get_replacement_idefics3,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Idefics3SimpleMLP(nn.Module):
|
||||
@ -453,7 +399,7 @@ class Idefics3Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
config: Idefics3Config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
@ -541,15 +487,13 @@ class Idefics3Model(nn.Module):
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> NestedTensors:
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
num_patches = [x.size(0) for x in pixel_values]
|
||||
pixel_values = pixel_values.to(
|
||||
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values.view(batch_size * num_images,
|
||||
*pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
@ -567,8 +511,6 @@ class Idefics3Model(nn.Module):
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
batch_size * num_images, *pixel_attention_mask.shape[2:])
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds].contiguous()
|
||||
|
||||
@ -587,10 +529,10 @@ class Idefics3Model(nn.Module):
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
return image_hidden_states
|
||||
return image_hidden_states.split(num_patches)
|
||||
|
||||
def _process_image_pixels(
|
||||
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
|
||||
self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
|
||||
assert self.vision_model is not None
|
||||
|
||||
pixel_values = inputs["data"]
|
||||
@ -605,7 +547,9 @@ class Idefics3Model(nn.Module):
|
||||
|
||||
assert self.vision_model is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
return self.connector(image_features)
|
||||
num_patches = [x.size(0) for x in image_features]
|
||||
image_features = torch.cat(image_features)
|
||||
return self.connector(image_features).split(num_patches)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -634,10 +578,10 @@ class Idefics3Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Idefics3MultimodalProcessor,
|
||||
info=Idefics3ProcessingInfo,
|
||||
dummy_inputs=Idefics3DummyInputsBuilder)
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
@ -689,7 +633,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if self.config.text_config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.text_model.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self.model._parse_and_validate_image_input(**kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user