[Model] Expose size to Idefics3 as mm_processor_kwargs (#10146)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent f10797c0ce
commit 1ff4aed5bd
@@ -382,10 +382,19 @@ def run_idefics3(question: str, modality: str):
     assert modality == "image"
     model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
 
-    llm = LLM(model=model_name,
-              max_model_len=8192,
-              max_num_seqs=2,
-              enforce_eager=True)
+    llm = LLM(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        enforce_eager=True,
+        # if you are running out of memory, you can reduce the "longest_edge".
+        # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
+        mm_processor_kwargs={
+            "size": {
+                "longest_edge": 3 * 364
+            },
+        },
+    )
     prompt = (
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
     )
@@ -300,6 +300,13 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
         max_num_seqs=16,
         enforce_eager=True,
         limit_mm_per_prompt={"image": len(image_urls)},
+        # if you are running out of memory, you can reduce the "longest_edge".
+        # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
+        mm_processor_kwargs={
+            "size": {
+                "longest_edge": 2 * 364
+            },
+        },
     )
 
     placeholders = "\n".join(f"Image-{i}: <image>\n"
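The `longest_edge` values in these two examples (3 * 364 for the single-image example, 2 * 364 for the multi-image one) are what trade visual detail for memory. A rough sketch of the effect, assuming the Idefics3 defaults that the new tests below also rely on (364-pixel image patches and 169 tokens per patch; both numbers are inferred from those tests, not stated in this diff):

# Illustrative sketch, not part of this commit.
def approx_image_tokens(longest_edge: int,
                        max_image_size: int = 364,  # assumed default patch size
                        image_seq_len: int = 169) -> int:  # assumed tokens per patch
    grid = longest_edge // max_image_size
    # "+ 1" accounts for the extra downscaled global image patch
    return (grid * grid + 1) * image_seq_len

print(approx_image_tokens(3 * 364))  # 1690 tokens per image (single-image example)
print(approx_image_tokens(2 * 364))  # 845 tokens per image (multi-image example)
print(approx_image_tokens(4 * 364))  # 2873, the default, matching the tests below

Because the patch grid grows quadratically with `longest_edge`, lowering it shrinks the per-image token budget quickly, which is why the comments suggest reducing it when memory is tight.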
@@ -0,0 +1,187 @@
+"""Tests for Idefics3's multimodal preprocessing kwargs."""
+from typing import Optional
+
+import pytest
+import torch
+import transformers
+from transformers import AutoImageProcessor, AutoTokenizer
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.multimodal import MultiModalRegistry
+
+from .....conftest import _ImageAssets
+from ....utils import build_model_context
+
+models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
+
+
+# Wrap lazy imports to avoid initializing CUDA during test collection
+@pytest.fixture()
+def input_processor_for_idefics3():
+    from vllm.model_executor.models.idefics3 import (
+        input_processor_for_idefics3)
+    return input_processor_for_idefics3
+
+
+@pytest.fixture()
+def dummy_data_for_idefics3():
+    from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3
+    return dummy_data_for_idefics3
+
+
+@pytest.fixture()
+def get_max_idefics3_image_tokens():
+    from vllm.model_executor.models.idefics3 import (
+        get_max_idefics3_image_tokens)
+    return get_max_idefics3_image_tokens
+
+
+@pytest.mark.skipif(transformers.__version__ < "4.46.0",
+                    reason="Model introduced in HF >= 4.46.0")
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336])
+def test_input_mapper_override(model: str, image_assets: _ImageAssets,
+                               longest_edge: Optional[int]):
+    """Ensure that the [default] input mapper handles size properly."""
+
+    mm_processor_kwargs = {
+        "size": {
+            "longest_edge": longest_edge
+        }
+    } if longest_edge is not None else {}
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    hf_processor = AutoImageProcessor.from_pretrained(model,
+                                                      trust_remote_code=True,
+                                                      **mm_processor_kwargs)
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image
+    hf_result = hf_processor.preprocess(
+        image,
+        return_tensors="pt",
+    )
+
+    vllm_result = mm_registry.map_input(
+        ctx.model_config,
+        {"image": image},
+    )
+
+    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
+
+
+@pytest.mark.skipif(transformers.__version__ < "4.46.0",
+                    reason="Model introduced in HF >= 4.46.0")
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("longest_edge, expected_max_tokens", [
+    (None, 2873),
+    (168, 169),
+    (336, 169),
+    (400, 338),
+    (672, 338),
+])
+def test_max_tokens_override(get_max_idefics3_image_tokens, model: str,
+                             longest_edge: Optional[int],
+                             expected_max_tokens: int):
+    """Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs."""
+    size = {"longest_edge": longest_edge} if longest_edge is not None else None
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    actual_max_tokens = get_max_idefics3_image_tokens(
+        ctx=InputContext(ctx.model_config),
+        size=size,
+    )
+
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.skipif(transformers.__version__ < "4.46.0",
+                    reason="Model introduced in HF >= 4.46.0")
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [
+    (168, 169, 1),
+    (168, 169, 2),
+    (400, 338, 1),
+    (400, 338, 2),
+])
+def test_dummy_data_override(dummy_data_for_idefics3, model: str,
+                             longest_edge: int, toks_per_img: int,
+                             num_imgs: int):
+    """Ensure dummy_data_for_idefics3 handles size properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the dummy data func.
+    size = {"longest_edge": longest_edge} if longest_edge is not None else None
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    dummy_data = dummy_data_for_idefics3(
+        ctx=ctx,
+        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
+        mm_counts={"image": num_imgs},
+        size=size)
+    sequence_data = dummy_data.seq_data
+    # Ensure we have the right number of placeholders per size
+    image_token_id = ctx.get_hf_config().image_token_id
+    img_tok_count = sequence_data.get_token_ids().count(image_token_id)
+    assert img_tok_count == toks_per_img * num_imgs
+
+
+@pytest.mark.skipif(transformers.__version__ < "4.46.0",
+                    reason="Model introduced in HF >= 4.46.0")
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [
+    (336, 169 * (1**2 + 1), 1),
+    (336, 169 * (1**2 + 1), 2),
+    (400, 169 * (2**2 + 1), 1),
+    (400, 169 * (2**2 + 1), 2),
+])
+def test_input_processor_override(input_processor_for_idefics3,
+                                  image_assets: _ImageAssets, model: str,
+                                  longest_edge: int,
+                                  expected_toks_per_img: int, num_imgs: int):
+    """Ensure input_processor_for_idefics3 handles size properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    size = {"longest_edge": longest_edge} if longest_edge is not None else None
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    # Build the image str / prompt based on the number of images we pass
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    placeholders = "<image>" if num_imgs == 1 else "\n".join(
+        f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
+    prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:"  # noqa: E501
+    images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs
+
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": images})
+
+    processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size)
+
+    # Ensure we have the right number of placeholders per size
+    image_token_id = ctx.get_hf_config().image_token_id
+    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
+    assert img_tok_count == expected_toks_per_img * num_imgs
@@ -14,8 +14,8 @@
 """Inference-only Idefics3 model compatible with HuggingFace weights."""
 
 import math
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
+                    Optional, Tuple, TypedDict, Union)
 
 import torch
 import torch.utils.checkpoint
@@ -23,6 +23,7 @@ from PIL import Image
 from torch import nn
 # Temporary solution for transformers below 4.46.0.
 from transformers import PretrainedConfig as Idefics3Config
+from transformers import ProcessorMixin as Idefics3ImageProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -72,16 +73,41 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
     """
 
 
+class Idefics3ProcessorSize(NamedTuple):
+    """Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
+    # NOTE: cached_get_processor/cached_get_image_processor use lru_cache,
+    # so we need a NamedTuple instead of a TypedDict to avoid hashing issues.
+    longest_edge: int
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._asdict() and getattr(self, key) is not None
+
+    def __getitem__(self, key: str) -> int:
+        return getattr(self, key)
+
+
 ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
 
 
+def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
+    mm_processor_kwargs = {}
+    if size:
+        mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
+    return mm_processor_kwargs
+
+
 def input_mapper_for_idefics3(
     ctx: InputContext,
     data: object,
+    *,
+    size: Optional[Dict[str, int]] = None,
 ):
     model_config = ctx.model_config
+    mm_processor_kwargs = get_mm_processor_kwargs(size)
     image_processor = cached_get_image_processor(
-        model_config.model, trust_remote_code=model_config.trust_remote_code)
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs)
     if image_processor is None:
         raise RuntimeError("No HuggingFace processor is available "
                            "to process the image object")
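A quick illustration of why the hashable wrapper above is needed (illustrative only, not part of this commit): the cached processor getters are wrapped in functools.lru_cache, which requires hashable arguments, so a plain `size` dict cannot pass through them, while Idefics3ProcessorSize can, and it still supports the dict-style `in` and `[]` access that the HF image processor performs on `size`.

# Illustrative sketch, assuming the Idefics3ProcessorSize class defined above.
from functools import lru_cache

@lru_cache
def cached_identity(size):
    return size

size_tuple = Idefics3ProcessorSize(longest_edge=3 * 364)
cached_identity(size_tuple)                 # works: NamedTuple instances are hashable
assert "longest_edge" in size_tuple         # dict-style membership via __contains__
assert size_tuple["longest_edge"] == 1092   # dict-style access via __getitem__
# cached_identity({"longest_edge": 1092})   # would raise TypeError: unhashable type: 'dict'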
@@ -201,13 +227,17 @@ def _get_image_prompt_string(image_rows: int, image_cols: int,
                                     global_img_token)
 
 
-def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs):
+def input_processor_for_idefics3(ctx: InputContext,
+                                 inputs: DecoderOnlyInputs,
+                                 *,
+                                 size: Optional[Dict[str, int]] = None):
     multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return inputs
 
     model_config = ctx.model_config
-    processor = cached_get_processor(model_config.model)
+    mm_processor_kwargs = get_mm_processor_kwargs(size)
+    processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
     image_processor = processor.image_processor
     tokenizer = processor.tokenizer
     size = image_processor.size['longest_edge']
@@ -286,32 +316,46 @@ def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs):
     )
 
 
-def get_max_idefics3_image_tokens(ctx: InputContext,
-                                  *,
-                                  num_crops: Optional[int] = None):
-    model_config = ctx.model_config
-    processor = cached_get_processor(model_config.model)
-    image_seq_len = processor.image_seq_len
-    image_processor = processor.image_processor
-
+def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
     size = image_processor.size['longest_edge']
     max_image_size = image_processor.max_image_size['longest_edge']
     resized_height, resized_width = size, size
 
     grid_h = resized_height // max_image_size
     grid_w = resized_width // max_image_size
-    return (grid_h * grid_w + 1) * image_seq_len
+    return (grid_h * grid_w + 1)
 
 
-def dummy_data_for_idefics3(ctx: InputContext, seq_len: int,
-                            mm_counts: Mapping[str, int]) -> DummyData:
+def get_max_idefics3_image_tokens(ctx: InputContext,
+                                  *,
+                                  size: Optional[Dict[str,
+                                                      int]] = None) -> int:
+    model_config = ctx.model_config
+    mm_processor_kwargs = get_mm_processor_kwargs(size)
+    processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
+    image_seq_len = processor.image_seq_len
+    image_processor = processor.image_processor
+
+    max_num_image_patches = _get_max_num_image_patch(image_processor)
+
+    return max_num_image_patches * image_seq_len
+
+
+def dummy_data_for_idefics3(
+        ctx: InputContext,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+        *,
+        size: Optional[Dict[str, int]] = None) -> DummyData:
     hf_config = ctx.get_hf_config()
     num_images = mm_counts["image"]
 
-    processor = cached_get_processor(ctx.model_config.model)
+    mm_processor_kwargs = get_mm_processor_kwargs(size)
+    processor = cached_get_processor(ctx.model_config.model,
+                                     **mm_processor_kwargs)
+    max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
     image_seq_len = processor.image_seq_len
-    max_llm_image_tokens = 17 * image_seq_len * num_images
+    max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
 
     seq_data = SequenceData.from_prompt_token_counts(
         (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
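For context on the dummy-data change at the end of this hunk: the previous code hard-coded 17 image patches per image, which is exactly what `_get_max_num_image_patch` computes for the default `size`, so the refactor preserves the default behaviour while letting a smaller `size` shrink the dummy data accordingly. A worked check, assuming a default `size` longest_edge of 4 * 364 and a `max_image_size` longest_edge of 364 (assumptions consistent with the (None, 2873) case in the new tests, not stated in this diff):

# Illustrative arithmetic, not part of this commit.
default_longest_edge = 4 * 364   # assumed default `size`
patch_longest_edge = 364         # assumed default `max_image_size`

grid = default_longest_edge // patch_longest_edge  # 4
patches = grid * grid + 1                          # 17, the previously hard-coded value
print(patches * 169)                               # 2873 max image tokens, as the tests expect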