[Model] Expose size to Idefics3 as mm_processor_kwargs (#10146)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-11-08 17:56:58 +08:00 committed by GitHub
parent f10797c0ce
commit 1ff4aed5bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 271 additions and 24 deletions

View File

@ -382,10 +382,19 @@ def run_idefics3(question: str, modality: str):
assert modality == "image"
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
llm = LLM(model=model_name,
max_model_len=8192,
max_num_seqs=2,
enforce_eager=True)
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
enforce_eager=True,
# if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={
"size": {
"longest_edge": 3 * 364
},
},
)
prompt = (
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
)
@ -518,4 +527,4 @@ if __name__ == "__main__":
default=16,
help='Number of frames to extract from the video.')
args = parser.parse_args()
main(args)
main(args)

View File

@ -300,6 +300,13 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
max_num_seqs=16,
enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)},
# if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={
"size": {
"longest_edge": 2 * 364
},
},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"

View File

@ -0,0 +1,187 @@
"""Tests for Idefics3's multimodal preprocessing kwargs."""
from typing import Optional
import pytest
import torch
import transformers
from transformers import AutoImageProcessor, AutoTokenizer
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from .....conftest import _ImageAssets
from ....utils import build_model_context
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_idefics3():
from vllm.model_executor.models.idefics3 import (
input_processor_for_idefics3)
return input_processor_for_idefics3
@pytest.fixture()
def dummy_data_for_idefics3():
from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3
return dummy_data_for_idefics3
@pytest.fixture()
def get_max_idefics3_image_tokens():
from vllm.model_executor.models.idefics3 import (
get_max_idefics3_image_tokens)
return get_max_idefics3_image_tokens
@pytest.mark.skipif(transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336])
def test_input_mapper_override(model: str, image_assets: _ImageAssets,
longest_edge: Optional[int]):
"""Ensure that the [default] input mapper handles size properly."""
mm_processor_kwargs = {
"size": {
"longest_edge": longest_edge
}
} if longest_edge is not None else {}
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)
hf_processor = AutoImageProcessor.from_pretrained(model,
trust_remote_code=True,
**mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
)
vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)
assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
@pytest.mark.skipif(transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge, expected_max_tokens", [
(None, 2873),
(168, 169),
(336, 169),
(400, 338),
(672, 338),
])
def test_max_tokens_override(get_max_idefics3_image_tokens, model: str,
longest_edge: Optional[int],
expected_max_tokens: int):
"""Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs."""
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
actual_max_tokens = get_max_idefics3_image_tokens(
ctx=InputContext(ctx.model_config),
size=size,
)
assert expected_max_tokens == actual_max_tokens
@pytest.mark.skipif(transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [
(168, 169, 1),
(168, 169, 2),
(400, 338, 1),
(400, 338, 2),
])
def test_dummy_data_override(dummy_data_for_idefics3, model: str,
longest_edge: int, toks_per_img: int,
num_imgs: int):
"""Ensure dummy_data_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
dummy_data = dummy_data_for_idefics3(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
size=size)
sequence_data = dummy_data.seq_data
# Ensure we have the right number of placeholders per size
image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
assert img_tok_count == toks_per_img * num_imgs
@pytest.mark.skipif(transformers.__version__ < "4.46.0",
reason="Model introduced in HF >= 4.46.0")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [
(336, 169 * (1**2 + 1), 1),
(336, 169 * (1**2 + 1), 2),
(400, 169 * (2**2 + 1), 1),
(400, 169 * (2**2 + 1), 2),
])
def test_input_processor_override(input_processor_for_idefics3,
image_assets: _ImageAssets, model: str,
longest_edge: int,
expected_toks_per_img: int, num_imgs: int):
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
# Build the image str / prompt based on the number of images we pass
tokenizer = AutoTokenizer.from_pretrained(model)
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501
images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size)
# Ensure we have the right number of placeholders per num_crops size
image_token_id = ctx.get_hf_config().image_token_id
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs

View File

@ -14,8 +14,8 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
Optional, Tuple, TypedDict, Union)
import torch
import torch.utils.checkpoint
@ -23,6 +23,7 @@ from PIL import Image
from torch import nn
# Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@ -72,16 +73,41 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
"""
class Idefics3ProcessorSize(NamedTuple):
"""Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
# NOTE: cached_get_processor/cached_get_image_processor uses lru_cache,
# we need to use NamedTuple instead of TypedDict to avoid hashing issues.
longest_edge: int
def __contains__(self, key: str) -> bool:
return key in self._asdict() and getattr(self, key) is not None
def __getitem__(self, key: str) -> int:
return getattr(self, key)
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
mm_processor_kwargs = {}
if size:
mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
return mm_processor_kwargs
def input_mapper_for_idefics3(
ctx: InputContext,
data: object,
*,
size: Optional[Dict[str, int]] = None,
):
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
@ -201,13 +227,17 @@ def _get_image_prompt_string(image_rows: int, image_cols: int,
global_img_token)
def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs):
def input_processor_for_idefics3(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
size: Optional[Dict[str, int]] = None):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
processor = cached_get_processor(model_config.model)
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_processor = processor.image_processor
tokenizer = processor.tokenizer
size = image_processor.size['longest_edge']
@ -286,32 +316,46 @@ def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs):
)
def get_max_idefics3_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
model_config = ctx.model_config
processor = cached_get_processor(model_config.model)
image_seq_len = processor.image_seq_len
image_processor = processor.image_processor
def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
size = image_processor.size['longest_edge']
max_image_size = image_processor.max_image_size['longest_edge']
resized_height, resized_width = size, size
grid_h = resized_height // max_image_size
grid_w = resized_width // max_image_size
return (grid_h * grid_w + 1) * image_seq_len
return (grid_h * grid_w + 1)
def dummy_data_for_idefics3(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]) -> DummyData:
def get_max_idefics3_image_tokens(ctx: InputContext,
*,
size: Optional[Dict[str,
int]] = None) -> int:
model_config = ctx.model_config
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
image_seq_len = processor.image_seq_len
image_processor = processor.image_processor
max_num_image_patches = _get_max_num_image_patch(image_processor)
return max_num_image_patches * image_seq_len
def dummy_data_for_idefics3(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
size: Optional[Dict[str, int]] = None) -> DummyData:
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
processor = cached_get_processor(ctx.model_config.model)
mm_processor_kwargs = get_mm_processor_kwargs(size)
processor = cached_get_processor(ctx.model_config.model,
**mm_processor_kwargs)
max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
image_seq_len = processor.image_seq_len
max_llm_image_tokens = 17 * image_seq_len * num_images
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
seq_data = SequenceData.from_prompt_token_counts(
(hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))