vllm/tests/multimodal/test_processor_kwargs.py

340 lines
13 KiB
Python
Raw Normal View History

from array import array
from typing import Mapping
from unittest.mock import patch
import pytest
import torch
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from ..models.utils import build_model_context
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that we can pass processor kwargs an override value
# to receive the intended result for things like sequence length etc.
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
# type validation, and just return the value of the kwarg that we
# clobber.
return num_crops
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor):
yield
@pytest.fixture
def use_dummy_data_mock():
"""Patches the internal model input processor with an override callable."""
def custom_dummy_data_factory(self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
custom_dummy_data_factory):
yield
# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
from vllm.model_executor.models.phi3v import Phi3VForCausalLM
return Phi3VForCausalLM
# lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}
### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
# If we have a value for num_crops, pass the override value and make
# sure we get that value as a return-value from out mock processor,
# otherwise fall back to the default value
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == DEFAULT_NUM_CROPS
### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
"""Ensure dummy data factories can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
mm_processor_kwargs):
"""Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
"""Ensure max token calcs can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Similar before, but since these kwargs get filtered,
# we always get our default value back.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert max_multimodal_tokens == DEFAULT_NUM_CROPS
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echos
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1