[Bugfix] Fix CI failures for InternVL and Mantis models (#12728)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 649550f27e
commit 18016a5e62
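
Editorial summary (inferred from the diff below): the version gates in the vision-language test settings previously compared the raw `transformers` version string with `<` / `>=`; they now go through `packaging.version.Version`. The `aria` entry drops its transformers>=4.48 skip entirely (keeping only the large-GPU mark) and `rhymes-ai/Aria` loses its `min_transformers_version` pin in the example-model registry, presumably because CI now runs transformers 4.48. The Mantis entry gains a skip on transformers>=4.48, whose HF-side implementation is incompatible. Finally, the legacy processor-kwargs test module (the 402-line file deleted at the end) is removed in favor of a new `test_hf_processor_kwargs` test added alongside the existing multimodal processing tests.
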
@@ -9,6 +9,7 @@ from pathlib import PosixPath
 from typing import Type
 
 import pytest
+from packaging.version import Version
 from transformers import AutoModelForVision2Seq
 from transformers import __version__ as TRANSFORMERS_VERSION
 
@@ -154,13 +155,7 @@ VLM_TEST_SETTINGS = {
         stop_str=["<|im_end|>"],
         image_size_factors=[(0.10, 0.15)],
         max_tokens=64,
-        marks=[
-            pytest.mark.skipif(
-                TRANSFORMERS_VERSION < "4.48.0",
-                reason="HF model requires transformers>=4.48.0",
-            ),
-            large_gpu_mark(min_gb=64),
-        ],
+        marks=[large_gpu_mark(min_gb=64)],
     ),
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
@@ -206,7 +201,7 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
         marks=[
             pytest.mark.skipif(
-                TRANSFORMERS_VERSION >= "4.48.0",
+                Version(TRANSFORMERS_VERSION) >= Version("4.48"),
                 reason="HF model is not compatible with transformers>=4.48.0",
             )
         ],
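
The hunk above is the substantive bugfix: `TRANSFORMERS_VERSION` is a plain string, and Python compares strings lexicographically, so the old guard mis-evaluates as soon as version components differ in digit count. A minimal standalone illustration (the installed version here is hypothetical):

    from packaging.version import Version

    installed = "4.9.2"  # hypothetical installed transformers version

    # String comparison is lexicographic: '9' > '4', so this wrongly
    # claims that 4.9.2 is at least 4.48.0.
    print(installed >= "4.48.0")  # True (wrong)

    # packaging compares release segments numerically: 9 < 48.
    print(Version(installed) >= Version("4.48"))  # False (correct)
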
@@ -339,6 +334,12 @@ VLM_TEST_SETTINGS = {
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
         patch_hf_runner=model_utils.mantis_patch_hf_runner,
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_VERSION) >= Version("4.48"),
+                reason="HF model is not compatible with transformers>=4.48.0",
+            )
+        ],
     ),
     "minicpmv_25": VLMTestInfo(
         models=["openbmb/MiniCPM-Llama3-V-2_5"],
@@ -224,8 +224,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
 
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
-    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
-                                                    min_transformers_version="4.48"),
+    "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
@@ -1,11 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from contextlib import nullcontext
+from types import MethodType
 from typing import cast
 from unittest.mock import MagicMock
 
 import numpy as np
 import pytest
+from transformers import ProcessorMixin
 
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -636,3 +638,70 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
         mm_data=mm_data,
         hf_processor_mm_kwargs={},
     )
+
+
+class _ProcessorProxy:
+
+    def __init__(self, processor: ProcessorMixin) -> None:
+        super().__init__()
+
+        self.__processor = processor
+
+    def __getattr__(self, key: str):
+        return getattr(self.__processor, key)
+
+    def __call__(
+        self,
+        text=None,
+        images=None,
+        videos=None,
+        exists=None,
+        return_tensors=None,
+    ):
+        return dict(exists=exists)
+
+
+@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-7B-Instruct"])  # Dummy
+# yapf: disable
+@pytest.mark.parametrize(
+    ("call_kwargs", "expected_kwargs"),
+    [
+        # Should ignore invalid kwargs
+        ({"does_not_exist": 100}, {"exists": None}),
+        ({"exists": 1}, {"exists": 1}),
+        ({"does_not_exist": 100, "exists": 1}, {"exists": 1}),
+    ],
+)
+# yapf: enable
+def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
+    model_config = ModelConfig(
+        model=model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="half",
+        revision=None,
+    )
+
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        model_config,
+        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+    )
+    orig_get_hf_processor = processor.info.get_hf_processor
+
+    def get_hf_processor(self, **kwargs):
+        assert kwargs == call_kwargs
+        return _ProcessorProxy(orig_get_hf_processor())
+
+    processor.info.get_hf_processor = MethodType(get_hf_processor,
+                                                 processor.info)
+
+    out_kwargs = processor._call_hf_processor(
+        prompt="",
+        mm_data={},
+        mm_kwargs=call_kwargs,
+    )
+
+    assert out_kwargs == expected_kwargs
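
Two idioms carry the new test. `_ProcessorProxy` forwards any attribute it does not define to the wrapped processor via `__getattr__`, and `types.MethodType` binds a plain function to a single object, letting the test intercept `processor.info.get_hf_processor` without patching the class. A minimal sketch of both, using hypothetical names:

    from types import MethodType

    class Greeter:
        def greet(self):
            return "hello"

    g = Greeter()

    def shout(self):  # plain function; MethodType supplies `self`
        return self.greet().upper()

    # Bound on this one instance only; the class is untouched.
    g.shout = MethodType(shout, g)
    assert g.shout() == "HELLO"

    class Proxy:
        # Forwards unknown attributes to the wrapped object,
        # like _ProcessorProxy above.
        def __init__(self, inner):
            self._inner = inner

        def __getattr__(self, key):  # only hit when normal lookup fails
            return getattr(self._inner, key)

    p = Proxy(g)
    assert p.greet() == "hello"
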
@@ -1,402 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from array import array
-from typing import Callable, Dict, Mapping, Optional
-from unittest.mock import patch
-
-import pytest
-import torch
-
-from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
-                         InputRegistry, ProcessorInputs, token_inputs)
-from vllm.multimodal import MultiModalRegistry
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
-
-from ..models.utils import build_model_context
-
-# Used for fast tests where the model doesn't matter
-DUMMY_MODEL_ID = "facebook/opt-125m"
-# Used for tests that need a multimodal model
-MULTIMODAL_MODEL_ID = "OpenGVLab/InternVL2-2B"
-
-# For mm_processor_kwargs - we test overrides by defining mocks for each place
-# it is used, and ensuring that we can pass processor kwargs an override value
-# to receive the intended result for things like sequence length etc.
-DEFAULT_MAX_DYNAMIC_PATCH = 6
-MAX_DYNAMIC_PATCH_OVERRIDE = 4
-
-
-# Mocks for all of the places that we use the mm_processor_kwargs
-# to override values in different callables
-@pytest.fixture
-def use_processor_mock():
-    """Patches the internal model input processor with an override callable."""
-
-    def custom_processor(ctx: InputContext,
-                         inputs: DecoderOnlyInputs,
-                         *,
-                         max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
-        # For testing purposes, we don't worry about the prompt
-        return token_inputs(
-            prompt_token_ids=[],
-            mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch})
-
-    with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
-               return_value=custom_processor):
-        yield
-
-
-@pytest.fixture
-def use_dummy_data_mock():
-    """Patches the internal model input processor with an override callable."""
-
-    def custom_dummy_data_factory(self,
-                                  ctx: InputContext,
-                                  seq_len: int,
-                                  mm_counts: Mapping[str, int],
-                                  *,
-                                  max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH):
-        seq_data = SequenceData(
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * max_dynamic_patch))
-        return DummyData(seq_data, None)
-
-    with patch(
-            "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
-            custom_dummy_data_factory):
-        yield
-
-
-# Lazy import to avoid CUDA reinitialization error
-def mm_model_cls():
-    from vllm.model_executor.models.internvl import InternVLChatModel
-
-    return InternVLChatModel
-
-
-# lambda whose signature matches max token calcs extra & mapper + extra kwargs
-get_max_dynamic_patch = lambda ctx, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: max_dynamic_patch  # noqa: E501
-custom_mapper = lambda ctx, data, *, max_dynamic_patch=DEFAULT_MAX_DYNAMIC_PATCH: {  # noqa: E501
-    "pixel_values": torch.zeros(size=(1, max_dynamic_patch + 1, 3, 448, 448))
-}
-
-
-### Tests for default processor logic & mm_processor_kwargs wrapping
-def test_default_processor_is_a_noop():
-    """Ensure that by default, there is no processor override."""
-    dummy_registry = InputRegistry()
-    ctx = build_model_context(DUMMY_MODEL_ID)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
-    proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
-    proc_outputs = processor(inputs=proc_inputs)
-    assert proc_inputs is proc_outputs
-
-
-def _get_max_dynamic_patch_info(init_max_dynamic_patch: int,
-                                inference_max_dynamic_patch: int):
-    """Get the init / inference kwargs and expected max_dynamic_patch."""
-    # If we have a value for max_dynamic_patch, pass the override value and make
-    # sure we get that value as a return-value from out mock processor,
-    # otherwise fall back to the default value
-    init_kwargs = None if init_max_dynamic_patch is None else {
-        "max_dynamic_patch": init_max_dynamic_patch
-    }
-    inference_kwargs = None if inference_max_dynamic_patch is None else {
-        "max_dynamic_patch": inference_max_dynamic_patch
-    }
-    if inference_max_dynamic_patch is not None:
-        expected_seq_count = inference_max_dynamic_patch
-    elif init_max_dynamic_patch is not None:
-        expected_seq_count = init_max_dynamic_patch
-    else:
-        expected_seq_count = DEFAULT_MAX_DYNAMIC_PATCH
-    return init_kwargs, inference_kwargs, expected_seq_count
-
-
-def _get_processed_max_dynamic_patch(
-    processor: Callable[[ProcessorInputs], ProcessorInputs],
-    inference_kwargs: Optional[Dict[str, int]],
-) -> int:
-    processed_inputs = processor(
-        token_inputs(prompt_token_ids=[],
-                     prompt="",
-                     mm_processor_kwargs=inference_kwargs))
-
-    assert "type" in processed_inputs
-    assert processed_inputs["type"] == "token"
-    assert "mm_processor_kwargs" in processed_inputs
-    return processed_inputs["mm_processor_kwargs"]["max_dynamic_patch"]
-
-
-@pytest.mark.parametrize(
-    "init_max_dynamic_patch,inference_max_dynamic_patch", [
-        (None, None),
-        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
-        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
-    ])
-def test_input_processor_kwargs(use_processor_mock, init_max_dynamic_patch,
-                                inference_max_dynamic_patch):
-    """Ensure input processors can use processor kwargs."""
-    dummy_registry = InputRegistry()
-
-    (init_kwargs, inference_kwargs,
-     expected_seq_count) = _get_max_dynamic_patch_info(
-         init_max_dynamic_patch, inference_max_dynamic_patch)
-
-    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
-    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
-        processor, inference_kwargs)
-
-    assert max_dynamic_patch_val == expected_seq_count
-
-
-@pytest.mark.parametrize(
-    "mm_processor_kwargs",
-    [
-        # Not part of the signature
-        {
-            "does_not_exist": 100
-        },
-        # Part of the signature, not keyword only
-        {
-            "ctx": "something bad"
-        }
-    ])
-def test_processor_with_sad_kwarg_overrides(use_processor_mock,
-                                            mm_processor_kwargs):
-    """Ensure that input processors filter out invalid mm_processor_kwargs"""
-    dummy_registry = InputRegistry()
-    # Should filter out the init time kwargs
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-
-    processor = dummy_registry.create_input_processor(ctx.model_config)
-    # Should filter out the inference time kwargs
-    max_dynamic_patch_val = _get_processed_max_dynamic_patch(
-        processor, mm_processor_kwargs)
-    assert max_dynamic_patch_val == DEFAULT_MAX_DYNAMIC_PATCH
-
-
-### Test overrides for the dummy data
-@pytest.mark.parametrize("max_dynamic_patch",
-                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
-def test_dummy_data_kwarg_overrides(use_dummy_data_mock, max_dynamic_patch):
-    """Ensure dummy data factories can use processor kwargs."""
-    mm_processor_kwargs = None if max_dynamic_patch is None else {
-        "max_dynamic_patch": max_dynamic_patch
-    }
-    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
-                          if max_dynamic_patch is None else max_dynamic_patch)
-    dummy_registry = InputRegistry()
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    # NOTE: seq_len is thrown away here since this will leverage the
-    # default dummy data factory that we have patched in, whose seq
-    # len is solely dependent on the value of the mm_processor_kwargs.
-    dummy_data = dummy_registry.dummy_data_for_profiling(
-        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
-
-
-@pytest.mark.parametrize(
-    "mm_processor_kwargs",
-    [
-        # Not part of the signature
-        {
-            "does_not_exist": 100
-        },
-        # Part of the signature, not keyword only
-        {
-            "ctx": "something bad"
-        }
-    ])
-def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
-                                             mm_processor_kwargs):
-    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
-    dummy_registry = InputRegistry()
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    # NOTE: seq_len is thrown away here since this will leverage the
-    # default dummy data factory that we have patched in, whose seq
-    # len is solely dependent on the value of the mm_processor_kwargs.
-    dummy_data = dummy_registry.dummy_data_for_profiling(
-        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(
-        dummy_data.seq_data.prompt_token_ids) == DEFAULT_MAX_DYNAMIC_PATCH
-
-
-### Test overrides for the max token count per multimodal instance
-@pytest.mark.parametrize("max_dynamic_patch",
-                         [None, MAX_DYNAMIC_PATCH_OVERRIDE])
-def test_max_tokens_kwarg_overrides(max_dynamic_patch):
-    """Ensure max token calcs can use processor kwargs."""
-    mm_processor_kwargs = None if max_dynamic_patch is None else {
-        "max_dynamic_patch": max_dynamic_patch
-    }
-    expected_seq_count = (DEFAULT_MAX_DYNAMIC_PATCH
-                          if max_dynamic_patch is None else max_dynamic_patch)
-
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
-                              limit_mm_per_prompt={"image": 1})
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our max_dynamic_patch value back from the mm_processor_kwargs.
-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_max_mm_tokens",
-            {mm_model_cls(): get_max_dynamic_patch},
-    ):
-        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
-            ctx.model_config)
-
-    assert expected_seq_count == max_multimodal_tokens
-
-
-@pytest.mark.parametrize(
-    "mm_processor_kwargs",
-    [
-        # Not part of the signature
-        {
-            "does_not_exist": 100
-        },
-        # Part of the signature, not keyword only
-        {
-            "ctx": "something bad"
-        }
-    ])
-def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
-    """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
-                              limit_mm_per_prompt={"image": 1})
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    # Similar before, but since these kwargs get filtered,
-    # we always get our default value back.
-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_max_mm_tokens",
-            {mm_model_cls(): get_max_dynamic_patch},
-    ):
-        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
-            ctx.model_config)
-
-    assert max_multimodal_tokens == DEFAULT_MAX_DYNAMIC_PATCH
-
-
-### Test overrides for the mapper
-@pytest.mark.parametrize(
-    "max_dynamic_patch",
-    [DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE])
-def test_default_mapper_with_processor_kwargs(image_assets, max_dynamic_patch):
-    """Ensure that the mapper processor kwargs can fall back to HF models."""
-    # NOTE - we don't validate bad inputs for the default mapper, because it's
-    # through the automodel interface in transformers, so we can't easily
-    # inspect what kwargs are or are not allowed.
-    ctx = build_model_context(
-        MULTIMODAL_MODEL_ID,
-        task="generate",
-        trust_remote_code=True,
-        mm_processor_kwargs={"max_dynamic_patch": max_dynamic_patch},
-        limit_mm_per_prompt={"image": 1})
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
-    image = image_assets[0].pil_image
-    mm_inputs = {"image": image}
-
-    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
-    # pixel vals should have shape: [batch, max_dynamic_patch+1, ...]
-    assert mapped_inputs["pixel_values"].shape[1] == max_dynamic_patch + 1
-
-
-@pytest.mark.parametrize(
-    "init_max_dynamic_patch,inference_max_dynamic_patch", [
-        (None, None),
-        (MAX_DYNAMIC_PATCH_OVERRIDE, None),
-        (DEFAULT_MAX_DYNAMIC_PATCH, MAX_DYNAMIC_PATCH_OVERRIDE),
-    ])
-def test_custom_mapper_kwarg_overrides(image_assets, init_max_dynamic_patch,
-                                       inference_max_dynamic_patch):
-    """Ensure custom mappers can use processor kwargs."""
-    (init_kwargs, inference_kwargs,
-     expected_seq_count) = _get_max_dynamic_patch_info(
-         init_max_dynamic_patch, inference_max_dynamic_patch)
-
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs=init_kwargs,
-                              limit_mm_per_prompt={"image": 1})
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    image = image_assets[0].pil_image
-    mm_inputs = {"image": image}
-
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our max_dynamic_patch value back from the mm_processor_kwargs.
-    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
-        mm_model_cls())
-    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
-                                          inference_kwargs)
-
-    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
-
-
-@pytest.mark.parametrize(
-    "mm_processor_kwargs",
-    [
-        # Not part of the signature
-        {
-            "does_not_exist": 100
-        },
-        # Part of the signature, not keyword only
-        {
-            "ctx": "something bad"
-        }
-    ])
-def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
-                                                mm_processor_kwargs):
-    """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
-    # Should filter out the init time kwargs
-    ctx = build_model_context(MULTIMODAL_MODEL_ID,
-                              task="generate",
-                              trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
-                              limit_mm_per_prompt={"image": 1})
-
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    image = image_assets[0].pil_image
-    mm_inputs = {"image": image}
-
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our max_dynamic_patch value back from the mm_processor_kwargs.
-    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
-        mm_model_cls())
-    # Should filter out the inference time kwargs
-    mapped_inputs = mm_registry.map_input(
-        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)
-
-    assert mapped_inputs["pixel_values"].shape[1] == (
-        DEFAULT_MAX_DYNAMIC_PATCH + 1)
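
The deleted suite repeatedly exercised one behavior, visible in its parametrizations: kwargs that are not keyword-only parameters of the target callable (`does_not_exist`, or the positional `ctx`) get filtered out, so the default value comes back. A minimal sketch of that filtering idea, with hypothetical names (not vLLM's actual implementation):

    import inspect

    def filter_mm_kwargs(fn, kwargs):
        # Keep only kwargs that match keyword-only parameters of `fn`.
        params = inspect.signature(fn).parameters
        return {
            k: v
            for k, v in kwargs.items()
            if k in params and params[k].kind is inspect.Parameter.KEYWORD_ONLY
        }

    def processor(data, *, max_dynamic_patch=6):
        return max_dynamic_patch

    kwargs = {"max_dynamic_patch": 4, "does_not_exist": 100, "data": "bad"}
    assert filter_mm_kwargs(processor, kwargs) == {"max_dynamic_patch": 4}
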