[VLM] merged multimodal processor and V1 support for idefics3 (#12660)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Isotr0py 2025-02-04 20:00:51 +08:00 committed by GitHub
parent 18a88fcccc
commit 815079de8e
7 changed files with 320 additions and 462 deletions


@@ -733,7 +733,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `HuggingFaceM4/Idefics3-8B-Llama3` etc.
* ✅︎
*
*
* ✅︎
- * `InternVLChatModel`
* InternVL 2.5, Mono-InternVL, InternVL 2.0
* T + I<sup>E+</sup>


@@ -254,14 +254,14 @@ VLM_TEST_SETTINGS = {
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
models=["HuggingFaceM4/Idefics3-8B-Llama3"],
models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
marks=[large_gpu_mark(min_gb=48)],
hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
),
"intern_vl": VLMTestInfo(
models=[


@@ -192,6 +192,14 @@ def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs
def idefics3_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<end_of_utterance>"):
output_str = output_str.split("<end_of_utterance>")[0]
return output_ids, output_str, out_logprobs
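This mirrors how the other HF reference outputs are truncated so they can be compared against vLLM outputs that stop before emitting the end token; a trivial standalone illustration:

```python
output_str = "A cat sitting on a mat.<end_of_utterance>"
if output_str.endswith("<end_of_utterance>"):
    output_str = output_str.split("<end_of_utterance>")[0]
print(output_str)  # A cat sitting on a mat.
```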
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output


@@ -149,6 +149,7 @@ def _test_processing_correctness(
"adept/fuyu-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",


@@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Idefics3's multimodal preprocessing kwargs."""
from typing import Optional
import pytest
import torch
from transformers import AutoImageProcessor, AutoTokenizer
from transformers import Idefics3Config
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
@@ -15,163 +12,53 @@ from ...utils import build_model_context
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_idefics3():
from vllm.model_executor.models.idefics3 import (
input_processor_for_idefics3)
return input_processor_for_idefics3
@pytest.fixture()
def dummy_data_for_idefics3():
from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3
return dummy_data_for_idefics3
@pytest.fixture()
def get_max_idefics3_image_tokens():
from vllm.model_executor.models.idefics3 import (
get_max_idefics3_image_tokens)
return get_max_idefics3_image_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336])
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
longest_edge: Optional[int]):
"""Ensure that the [default] input mapper handles size properly."""
mm_processor_kwargs = {
"size": {
"longest_edge": longest_edge
}
} if longest_edge is not None else {}
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)
hf_processor = AutoImageProcessor.from_pretrained(model,
trust_remote_code=True,
**mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
)
vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)
assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge, expected_max_tokens", [
(None, 2873),
(168, 169),
(336, 169),
(400, 338),
(672, 338),
# yapf: disable
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
[
({"size": {"longest_edge": 364}}, 169),
({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)),
])
def test_max_tokens_override(get_max_idefics3_image_tokens, model: str,
longest_edge: Optional[int],
expected_max_tokens: int):
"""Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs."""
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
actual_max_tokens = get_max_idefics3_image_tokens(
ctx=InputContext(ctx.model_config),
size=size,
)
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [
(168, 169, 1),
(168, 169, 2),
(400, 338, 1),
(400, 338, 2),
])
def test_dummy_data_override(dummy_data_for_idefics3, model: str,
longest_edge: int, toks_per_img: int,
num_imgs: int):
"""Ensure dummy_data_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
dummy_data = dummy_data_for_idefics3(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
size=size)
sequence_data = dummy_data.seq_data
# Ensure we have the right number of placeholders per size
image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
assert img_tok_count == toks_per_img * num_imgs
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [
(336, 169 * (1**2 + 1), 1),
(336, 169 * (1**2 + 1), 2),
(400, 169 * (2**2 + 1), 1),
(400, 169 * (2**2 + 1), 2),
])
def test_input_processor_override(input_processor_for_idefics3,
image_assets: _ImageAssets, model: str,
longest_edge: int,
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(image_assets: _ImageAssets, model: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int, num_imgs: int):
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
# Build the image str / prompt based on the number of images we pass
tokenizer = AutoTokenizer.from_pretrained(model)
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501
images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
# Build mm_data
image_size = ctx.get_hf_config(Idefics3Config).vision_config.image_size
dummy_image_size = (image_size * 4, image_size * 4)
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure the placeholders format are correct
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
"input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size
image_token_id = ctx.get_hf_config().image_token_id


@@ -31,6 +31,17 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
"""
A dictionary that can be hashed by lru_cache.
"""
# NOTE: pythonic dict is not hashable,
# we override on it directly for simplicity
def __hash__(self) -> int: # type: ignore[override]
return hash(frozenset(self.items()))
@dataclass(frozen=True)
class InputContext:
"""
@@ -104,6 +115,13 @@ class InputContext:
if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
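For context, a minimal standalone sketch of why the wrapper matters (assuming only functools.lru_cache; cached_lookup is a stand-in for cached_get_processor): a plain dict kwarg such as {"longest_edge": 364} is unhashable and would break the cache, while the HashableDict wrapper passes through cleanly.

```python
from functools import lru_cache


class HashableDict(dict):
    """dict that can be passed through lru_cache (same trick as above)."""

    def __hash__(self) -> int:  # type: ignore[override]
        return hash(frozenset(self.items()))


@lru_cache
def cached_lookup(size):
    # Stand-in for cached_get_processor: just echo the kwarg back.
    return dict(size)


# cached_lookup({"longest_edge": 364})  # would raise: unhashable type 'dict'
print(cached_lookup(HashableDict(longest_edge=364)))  # {'longest_edge': 364}
```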


@@ -16,35 +16,35 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
Optional, Set, Tuple, TypedDict, Union)
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import torch
import torch.utils.checkpoint
from PIL import Image
from torch import nn
# Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
Idefics3Processor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import NestedTensors
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
# yapf: disable
from .idefics2_vision_model import (
@@ -77,68 +77,55 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
"""
class Idefics3ProcessorSize(NamedTuple):
"""Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
# NOTE: cached_get_processor/cached_get_image_processor uses lru_cache,
# we need to use NamedTuple instead of TypedDict to avoid hashing issues.
longest_edge: int
def __contains__(self, key: str) -> bool:
return key in self._asdict() and getattr(self, key) is not None
def __getitem__(self, key: str) -> int:
return getattr(self, key)
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
mm_processor_kwargs = {}
if size:
mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
return mm_processor_kwargs
class Idefics3ProcessingInfo(BaseProcessingInfo):
def input_mapper_for_idefics3(
ctx: InputContext,
data: object,
def get_hf_processor(
self,
*,
size: Optional[Dict[str, int]] = None,
):
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
size: Optional[Dict[str, int]] = None) -> Idefics3Processor:
if size is not None:
return self.ctx.get_hf_processor(Idefics3Processor, size=size)
if isinstance(data, Image.Image):
images = [[data]]
elif is_list_of(data, Image.Image):
images = [data]
else:
raise TypeError(f"Invalid image type: {type(data)}")
return self.ctx.get_hf_processor(Idefics3Processor)
try:
batch_data = image_processor(images,
return_tensors="pt",
return_row_col_info=True).data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
return MultiModalKwargs(batch_data)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
grid_w, grid_h = self._get_image_feature_grid_size(
image_width=image_processor.size['longest_edge'],
image_height=image_processor.size['longest_edge'],
)
num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
# Calculate Non-image-token length
# NOTE: <row_1_col_1> and <global-img> are special tokens for SmolVLM
# but not for Idefics3, so we need to tokenize them to get the actual length.
tokenizer = self.get_tokenizer()
tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
# linebreak and <fake_token_around_image> always cost 1 token
fake_token_len = lb_len = 1
non_image_token = (grid_w * grid_h) * (
tile_token_len + fake_token_len) + glob_token_len + (
grid_h + 1) * lb_len + fake_token_len
return {"image": num_image_token + non_image_token}
def _resize_output_size(height: int,
def _resize_output_size(self,
*,
height: int,
width: int,
max_len: Optional[int] = None,
min_len: Optional[int] = 1,
max_size: Optional[int] = None) -> Tuple[int, int]:
max_size: Optional[int] = None) -> tuple[int, int]:
# Set default value for max_len if not provided
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
@@ -156,8 +143,8 @@ def _resize_output_size(height: int,
width = int(height * aspect_ratio)
# Ensure both width and height are even (if needed)
height += 1 if height % 2 != 0 else 0
width += 1 if width % 2 != 0 else 0
height += height % 2
width += width % 2
# Ensure dimensions are not smaller than the minimum length
height = max(height, min_len)
@@ -165,219 +152,178 @@ def _resize_output_size(height: int,
return height, width
def _get_resize_output_image_size(
image_size: Tuple[int, int],
self,
*,
image_width: int,
image_height: int,
resolution_max_side: int,
max_image_size: int = 1820,
) -> Tuple[int, int]:
) -> tuple[int, int]:
hf_processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
max_image_size = image_processor.size['longest_edge']
if resolution_max_side > max_image_size:
raise ValueError(
"`resolution_max_side` cannot be larger than `max_image_size`")
height, width = image_size
height, width = image_height, image_width
# Find the output size, when rescaling the longest edge to max_len and
# preserving the aspect ratio
height, width = _resize_output_size(height,
width,
height, width = self._resize_output_size(height=height,
width=width,
max_len=resolution_max_side)
return height, width
def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int,
fake_token_around_image: str, image_token: str,
global_img_token: str) -> str:
"""
Prompt with expanded image tokens for when the image is split
into patches.
"""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (fake_token_around_image +
f"<row_{n_h + 1}_col_{n_w + 1}>" +
image_token * image_seq_len)
text_split_images += "\n"
text_split_images += "\n" + _prompt_single_image(
image_seq_len=image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token)
return text_split_images
def _prompt_single_image(image_seq_len: int, fake_token_around_image: str,
image_token: str, global_img_token: str):
"""Prompt with expanded image tokens for a single image."""
return (fake_token_around_image + global_img_token +
image_token * image_seq_len + fake_token_around_image)
def _get_image_prompt_string(image_rows: int, image_cols: int,
image_seq_len: int, fake_token_around_image: str,
image_token: str, global_img_token: str):
if image_rows == 0 and image_cols == 0:
return _prompt_single_image(
image_seq_len=image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token,
)
return _prompt_split_image(image_seq_len, image_rows, image_cols,
fake_token_around_image, image_token,
global_img_token)
def input_processor_for_idefics3(ctx: InputContext,
inputs: DecoderOnlyInputs,
def _get_image_feature_grid_size(
self,
*,
size: Optional[Dict[str, int]] = None):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_processor = processor.image_processor
tokenizer = processor.tokenizer
size = image_processor.size['longest_edge']
image_width: int,
image_height: int,
size: Optional[dict[str, object]] = None,
) -> tuple[int, int]:
hf_processor = self.get_hf_processor(size=size)
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
max_image_size = image_processor.max_image_size['longest_edge']
size = image_processor.size['longest_edge']
assert size % max_image_size == 0, (
"`longest_edge` in image_processor's `size` must be divisible by "
"`longest_edge` in `max_image_size`, this may be caused by "
"incorrect mm_kwargs override.")
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_list = [image_data]
elif is_list_of(image_data, Image.Image):
image_list = image_data
resized_height, resized_width = self._get_resize_output_image_size(
image_width=image_width,
image_height=image_height,
resolution_max_side=size,
)
if resized_height > max_image_size or resized_width > max_image_size:
grid_h = math.ceil(resized_height / max_image_size)
grid_w = math.ceil(resized_width / max_image_size)
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
image_rows = []
image_cols = []
for image in image_list:
height, width = _get_resize_output_image_size(image.size, size)
rows = math.ceil(height / max_image_size)
cols = math.ceil(width / max_image_size)
image_rows.append(rows)
image_cols.append(cols)
image_rows = [image_rows]
image_cols = [image_cols]
n_images_in_text = []
text = inputs.get("prompt")
if text is None:
prompt_token_ids = inputs.get("prompt_token_ids", [])
assert prompt_token_ids
text = tokenizer.decode(prompt_token_ids)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, "
"or a list of strings")
fake_image_token = processor.fake_image_token.content
image_token = processor.image_token.content
global_img_token = processor.global_image_tag
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
n_images_in_text.append(sample.count(image_token))
# Replace the image token with fake tokens around the expanded
# image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = _get_image_prompt_string(
n_rows,
n_cols,
processor.image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=prompt_strings[0],
multi_modal_data=multi_modal_data,
)
grid_h = grid_w = 0
return grid_w, grid_h
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
size = image_processor.size['longest_edge']
max_image_size = image_processor.max_image_size['longest_edge']
resized_height, resized_width = size, size
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
):
grid_h = resized_height // max_image_size
grid_w = resized_width // max_image_size
return (grid_h * grid_w + 1)
def get_max_idefics3_image_tokens(ctx: InputContext,
*,
size: Optional[Dict[str,
int]] = None) -> int:
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_seq_len = processor.image_seq_len
image_processor = processor.image_processor
max_num_image_patches = _get_max_num_image_patch(image_processor)
return max_num_image_patches * image_seq_len
def dummy_data_for_idefics3(
ctx: InputContext,
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
*,
size: Optional[Dict[str, int]] = None) -> DummyData:
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge']
image_token: str = hf_processor.image_token.content
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(ctx.model_config.model,
**mm_processor_kwargs)
max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
image_seq_len = processor.image_seq_len
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
mm_data = {
"image":
self._get_dummy_images(width=longest_edge,
height=longest_edge,
num_images=num_images)
}
if seq_len - max_llm_image_tokens < 0:
raise RuntimeError(
f"Idefics3 cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
seq_data = SequenceData.from_prompt_token_counts(
(hf_config.image_token_id, max_llm_image_tokens),
(0, seq_len - max_llm_image_tokens))
width = height = hf_config.vision_config.image_size
image = Image.new("RGB", (width, height), color=0)
mm_data = {"image": [image] if num_images == 1 else [image] * num_images}
class Idefics3MultimodalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
return DummyData(seq_data, mm_data)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
image_grids = [
self.info._get_image_feature_grid_size(
image_width=img.width,
image_height=img.height,
**mm_kwargs,
) for img in mm_data["images"]
]
image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
for key in ("pixel_values", "pixel_attention_mask"):
data = processed_outputs.pop(key)
data = data.flatten(0, 1).split(image_patches)
processed_outputs[key] = data
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
return processed_outputs
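The flatten/split step above regroups the processor output into one variable-length tensor per image; a standalone torch sketch of the same idea (the shapes here are illustrative, not the exact HF output shapes):

```python
import torch

# Two images whose resized sizes yield a 2x2 tile grid and no tiling:
# 2*2 + 1 = 5 patches and 1 patch (global image only), respectively.
image_patches = [5, 1]

# Pretend the HF processor stacked every patch behind one leading batch dim.
pixel_values = torch.randn(1, sum(image_patches), 3, 364, 364)

per_image = pixel_values.flatten(0, 1).split(image_patches)
print([tuple(t.shape) for t in per_image])
# [(5, 3, 364, 364), (1, 3, 364, 364)]
```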
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_attention_mask=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
fake_image_token = hf_processor.fake_image_token.content
global_img_token = hf_processor.global_image_tag
image_seq_len = hf_processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
p_img = image_token * image_seq_len
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
def get_replacement_idefics3(item_idx: int) -> str:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
grid_w, grid_h = self.info._get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
**hf_processor_mm_kwargs,
)
if grid_w == 0 and grid_h == 0:
image_placeholder = global_img_placeholder
else:
tiles_placeholder = list[str]()
for i in range(grid_h):
for j in range(grid_w):
placeholder_per_tile = tile_img_placeholder.format(
n_h=i + 1, n_w=j + 1)
tiles_placeholder.append(placeholder_per_tile)
# Add line break if it is the last tile in the row
if j == grid_w - 1:
tiles_placeholder.append("\n")
image_placeholder = "".join(
[*tiles_placeholder, "\n", global_img_placeholder])
return image_placeholder + fake_image_token
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_idefics3,
)
]
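For illustration, the replacement string built above for an image that resolves to a 2x2 grid has the following shape (a sketch using image_seq_len = 2 instead of 169 to keep it readable; the literal special-token strings are assumed to match the HF Idefics3 processor):

```python
image_token = "<image>"
fake = "<fake_token_around_image>"   # hf_processor.fake_image_token.content
global_img = "<global-img>"          # hf_processor.global_image_tag
seq_len = 2                          # the real image_seq_len is 169

tiles = []
for i in range(2):          # grid_h
    for j in range(2):      # grid_w
        tiles.append(f"{fake}<row_{i + 1}_col_{j + 1}>" + image_token * seq_len)
        if j == 1:          # line break after the last tile of each row
            tiles.append("\n")

placeholder = "".join(tiles + ["\n", fake + global_img + image_token * seq_len])
print(placeholder + fake)
```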
class Idefics3SimpleMLP(nn.Module):
@@ -453,7 +399,7 @@ class Idefics3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config: Idefics3Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
@@ -541,15 +487,13 @@ class Idefics3Model(nn.Module):
self,
pixel_values: torch.Tensor,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
) -> NestedTensors:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
batch_size, num_images, num_channels, height, width = pixel_values.shape
num_patches = [x.size(0) for x in pixel_values]
pixel_values = pixel_values.to(
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images,
*pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
@@ -567,8 +511,6 @@ class Idefics3Model(nn.Module):
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
@@ -587,10 +529,10 @@ class Idefics3Model(nn.Module):
patch_attention_mask=patch_attention_mask,
)
return image_hidden_states
return image_hidden_states.split(num_patches)
def _process_image_pixels(
self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
assert self.vision_model is not None
pixel_values = inputs["data"]
@@ -605,7 +547,9 @@ class Idefics3Model(nn.Module):
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
return self.connector(image_features)
num_patches = [x.size(0) for x in image_features]
image_features = torch.cat(image_features)
return self.connector(image_features).split(num_patches)
def get_input_embeddings(
self,
@@ -634,10 +578,10 @@ class Idefics3Model(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
@MULTIMODAL_REGISTRY.register_processor(
Idefics3MultimodalProcessor,
info=Idefics3ProcessingInfo,
dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
packed_modules_mapping = {
@@ -689,7 +633,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self.model._parse_and_validate_image_input(**kwargs)
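With the merged processor registered, Idefics3 runs through the standard vLLM multimodal path; a minimal usage sketch (the prompt format mirrors the test settings above; the image path and sampling values are illustrative):

```python
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="HuggingFaceM4/Idefics3-8B-Llama3",
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)
prompt = ("<|begin_of_text|>User:<image>\nDescribe the image."
          "<end_of_utterance>\nAssistant:")
image = Image.open("example.jpg")  # any local RGB image

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```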