[Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex Brooks 2024-09-23 01:44:48 -06:00 committed by GitHub
parent d23679eb99
commit 9b8c8ba119
16 changed files with 590 additions and 117 deletions

View File

@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected):
def test_bad_nullable_kvs(arg):
with pytest.raises(ArgumentTypeError):
nullable_kvs(arg)
@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("{}", {}),
('{"num_crops": 4}', {
"num_crops": 4
}),
('{"foo": {"bar": "baz"}}', {
"foo": {
"bar": "baz"
}
}),
])
def test_mm_processor_kwargs_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--mm-processor-kwargs", arg])
assert args.mm_processor_kwargs == expected
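For reference, a minimal end-to-end sketch of how the new flag is consumed, assuming the EngineArgs / FlexibleArgumentParser interfaces exercised by this test:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# The flag is parsed as JSON and carried on EngineArgs, from which it is
# threaded into ModelConfig.mm_processor_kwargs when the engine is built.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--mm-processor-kwargs", '{"num_crops": 4}'])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.mm_processor_kwargs == {"num_crops": 4}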

View File

@ -5,14 +5,13 @@ import pytest
import torch
from PIL.Image import Image
from vllm.config import ModelConfig
from vllm.inputs import InputContext, LLMInputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
VllmRunner, _ImageAssets)
from ...utils import check_logprobs_close
from ...utils import build_model_context, check_logprobs_close
text_only_models = [
"Qwen/Qwen-7B-Chat" # Has no visual component
@ -42,32 +41,6 @@ VIS_ENC_DIM = 4096
IMG_SIZE = 448
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
trust_remote_code: Whether or not to allow loading remote code.
Returns:
InputContext for the model being considered.
"""
if tokenizer_name is None:
tokenizer_name = model_name
model_config = ModelConfig(
model_name,
tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float32",
seed=0,
)
return InputContext(model_config)
@pytest.fixture()
def input_mapper_for_qwen():
# Lazy import to avoid initializing CUDA during test collection

View File

@ -1,6 +1,8 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
TokensText = Tuple[List[int], str]
@ -240,3 +242,36 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
def build_model_context(model_name: str,
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False,
mm_processor_kwargs: Optional[Dict] = None,
limit_mm_per_prompt: Optional[Dict] = None):
"""Creates an InputContext for a given model.
Args:
model_name: Name of the model being considered.
tokenizer_name: Name of the tokenizer being considered.
trust_remote_code: Whether or not to allow loading remote code.
mm_processor_kwargs: Optional processor kwargs to be leveraged
in the input processor, mapper, dummy data creation, etc.
limit_mm_per_prompt: Multimodal limits.
Returns:
InputContext for the model being considered.
"""
if tokenizer_name is None:
tokenizer_name = model_name
model_config = ModelConfig(
model_name,
tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype="float32",
seed=0,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt,
)
return InputContext(model_config)
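As a usage sketch (model name borrowed from the tests below), the helper now threads the overrides into the InputContext it builds:

# Build a context whose ModelConfig carries the processor overrides; the
# registries read them back via model_config.mm_processor_kwargs.
ctx = build_model_context(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},
    limit_mm_per_prompt={"image": 1},
)
assert ctx.model_config.mm_processor_kwargs == {"num_crops": 16}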

View File

@ -0,0 +1,339 @@
from array import array
from typing import Mapping
from unittest.mock import patch
import pytest
import torch
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from ..models.utils import build_model_context
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
# For mm_processor_kwargs - we test overrides by defining mocks for each place
# they are used, and ensuring that we can pass an override value through the
# processor kwargs to receive the intended result for things like sequence
# length etc.
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
# type validation, and just return the value of the kwarg that we
# clobber.
return num_crops
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor):
yield
@pytest.fixture
def use_dummy_data_mock():
"""Patches the internal model input processor with an override callable."""
def custom_dummy_data_factory(self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
custom_dummy_data_factory):
yield
# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
from vllm.model_executor.models.phi3v import Phi3VForCausalLM
return Phi3VForCausalLM
# Lambdas whose signatures match the max token calc & input mapper, plus extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}
### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
"""Ensure input processors can use processor kwargs."""
dummy_registry = InputRegistry()
# If we have a value for num_crops, pass the override value and make
# sure we get that value as the return value from our mock processor;
# otherwise fall back to the default value
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == expected_num_crops
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
mm_processor_kwargs):
"""Ensure that input processors filter out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
assert num_crops_val == DEFAULT_NUM_CROPS
### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
"""Ensure dummy data factories can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
mm_processor_kwargs):
"""Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID,
mm_processor_kwargs=mm_processor_kwargs)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
"""Ensure max token calcs can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echoes
# our num_crops value back from the mm_processor_kwargs.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Similar to before, but since these kwargs get filtered,
# we always get our default value back.
with patch.object(
mm_registry._get_plugin("image"),
"_max_mm_tokens",
{mm_model_cls(): get_num_crops},
):
max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
ctx.model_config)
assert max_multimodal_tokens == DEFAULT_NUM_CROPS
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
"""Ensure that the default mapper forwards processor kwargs to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": image}
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
# Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
"""Ensure custom mappers can use processor kwargs."""
mm_processor_kwargs = None if num_crops is None else {
"num_crops": num_crops
}
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echoes
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
"mm_processor_kwargs",
[
# Not part of the signature
{
"does_not_exist": 100
},
# Part of the signature, not keyword only
{
"ctx": "something bad"
}
])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
mm_processor_kwargs):
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
# Patch the image registry for phi3v with our lambda that is compatible
# with overrides, then ensure that calling the method correctly echoes
# our num_crops value back from the mm_processor_kwargs.
image = image_assets[0].pil_image
mm_inputs = {"image": image}
with patch.object(
mm_registry._get_plugin("image"),
"_default_input_mapper",
{mm_model_cls(): custom_mapper},
):
mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1

View File

@ -122,6 +122,8 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
"""
def __init__(self,
@ -150,7 +152,8 @@ class ModelConfig:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
config_format: ConfigFormat = ConfigFormat.AUTO,
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@ -184,6 +187,7 @@ class ModelConfig:
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:

View File

@ -175,6 +175,7 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self):
if self.tokenizer is None:
@ -513,6 +514,12 @@ class EngineArgs:
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'))
parser.add_argument(
'--mm-processor-kwargs',
default=None,
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: {"num_crops": 4}.'))
# LoRA related configs
parser.add_argument('--enable-lora',
@ -822,6 +829,7 @@ class EngineArgs:
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
)
def create_load_config(self) -> LoadConfig:

View File

@ -235,7 +235,7 @@ class LLMEngine:
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s)",
"use_async_output_proc=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -268,6 +268,7 @@ class LLMEngine:
scheduler_config.num_scheduler_steps,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins

View File

@ -134,6 +134,7 @@ class LLM:
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
'''
@ -174,6 +175,7 @@ class LLM:
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
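A hedged usage sketch of the new constructor argument; the model name is illustrative and the kwargs are assumed to mirror the --mm-processor-kwargs CLI flag:

from vllm import LLM

# Forward num_crops to the HF image processor for every multimodal request
# served by this engine instance.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    mm_processor_kwargs={"num_crops": 16},
)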

View File

@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides
from .data import LLMInputs
@ -68,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
**mm_processor_kwargs: Any,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time for config values that may affect the
number of tokens per instance.
"""
...
@ -152,6 +158,10 @@ class InputRegistry:
return wrapper
def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
@ -174,15 +184,15 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
dummy_factory = self._get_dummy_data_factory(model_cls)
seq_data, mm_data = dummy_factory(
InputContext(model_config),
seq_len,
_MultiModalCounts(mm_counts),
)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
**mm_processor_kwargs)
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
@ -229,6 +239,10 @@ class InputRegistry:
return wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs:
"""
@ -243,15 +257,17 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)
processor = self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
processor, overrides=model_config.mm_processor_kwargs)
return processor(InputContext(model_config), inputs)
return processor(InputContext(model_config), inputs,
**mm_processor_kwargs)
def create_input_processor(self, model_config: "ModelConfig"):
"""
Create an input processor (see :meth:`process_input`) for a
Create an input processor (see :meth:`_process_input`) for a
specific model.
"""
return functools.partial(self.process_input, model_config)
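To illustrate the keyword-only contract that process_input now relies on, a minimal sketch (function and kwarg names hypothetical) of an input processor eligible to receive filtered overrides:

from vllm.inputs import InputContext, LLMInputs

def my_input_processor(ctx: InputContext, llm_inputs: LLMInputs, *,
                       num_crops: int = 4) -> LLMInputs:
    # Only keyword-only parameters (num_crops here) are eligible for override
    # from ModelConfig.mm_processor_kwargs; anything else is filtered out.
    return llm_inputs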

View File

@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, is_list_of, json_map_leaves
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
json_map_leaves)
logger = init_logger(__name__)
@ -256,11 +257,20 @@ class MultiModalPlugin(ABC):
model_cls, _ = get_model_architecture(model_config)
mapper = self._input_mappers.get(model_cls)
# Only get processor kwargs at mapping time if we are not using the
# default input mapper; no overrides are used on the default here because
# they should be passed to the huggingface resource at initialization time.
if mapper is not None and mapper != self._default_input_mapper:
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
mapper, overrides=model_config.mm_processor_kwargs)
else:
mm_processor_kwargs = {}
if mapper is None:
raise KeyError(f"No input mapper in {self} is registered for "
f"model class {model_cls.__name__}.")
return mapper(InputContext(model_config), data)
return mapper(InputContext(model_config), data, **mm_processor_kwargs)
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
@ -333,7 +343,10 @@ class MultiModalPlugin(ABC):
f"for model class {model_cls.__name__} in {self}.")
if callable(max_mm_tokens):
max_mm_tokens = max_mm_tokens(InputContext(model_config))
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
max_mm_tokens, overrides=model_config.mm_processor_kwargs)
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
self._validate_max_multimodal_tokens(max_mm_tokens)
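Similarly, a registered max-token callable can expose keyword-only kwargs to receive the same overrides; a minimal, purely illustrative sketch:

from vllm.inputs import InputContext

def get_max_dummy_image_tokens(ctx: InputContext, *, num_crops: int = 4) -> int:
    # Purely illustrative token count: one token per crop plus a global image
    # token, echoing how the tests above count tokens.
    return num_crops + 1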

View File

@ -6,7 +6,7 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@ -23,9 +23,14 @@ class ImagePlugin(MultiModalPlugin):
return "image"
def _get_hf_image_processor(self, model_config: ModelConfig):
mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
else model_config.mm_processor_kwargs)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code)
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper(
self,
@ -37,6 +42,7 @@ class ImagePlugin(MultiModalPlugin):
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")

View File

@ -138,6 +138,15 @@ class MultiModalRegistry:
"""
Create an input mapper (see :meth:`map_input`) for a specific model.
"""
# NOTE - we currently make the assumption that if a model has multiple
# supported modalities, they take the same kwargs. For the default,
# this could be an issue in the future if it falls back to two HF
# resources and we can't inspect the signature easily since it's
# getting initialized through the autoclass.
#
# If this is a problem in the future, we should revisit it, but since
# it potentially introduces a lot of complexity for a currently
# uncommon case, we do not do so, for simplicity of both use & implementation.
return functools.partial(self.map_input, model_config)
def register_max_multimodal_tokens(

View File

@ -6,7 +6,7 @@ import numpy as np
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_video_processor
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin):
return "video"
def _get_hf_video_processor(self, model_config: ModelConfig):
mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
else model_config.mm_processor_kwargs)
# We don't explicitly check kwarg overrides to the HF class
# since the automodel just takes kwargs, so we can't inspect it
return cached_get_video_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code)
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs)
def _default_input_mapper(
self,

View File

@ -1,64 +0,0 @@
from typing import cast
def get_video_processor(
processor_name: str,
trust_remote_code: bool = False,
):
"""
Gets a processor for the given model name via HuggingFace.
"""
from transformers import AutoProcessor
try:
processor = AutoProcessor.from_pretrained(processor_name)
video_processor = processor.video_processor
except ValueError as e:
if not trust_remote_code:
err_msg = (
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return video_processor
def get_image_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
**kwargs,
):
"""Gets an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try:
processor = AutoImageProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the image processor. If the image processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return cast(BaseImageProcessor, processor)

View File

@ -1,13 +1,13 @@
from typing import cast
from typing import Any, cast
def get_processor(
processor_name: str,
*args,
*args: Any,
trust_remote_code: bool = False,
**kwargs,
**kwargs: Any,
):
"""Gets a processor for the given model name via HuggingFace."""
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
@ -35,3 +35,60 @@ def get_processor(
raise e
return cast(ProcessorMixin, processor)
def get_image_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
**kwargs: Any,
):
"""Load an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try:
processor = AutoImageProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the image processor. If the image processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return cast(BaseImageProcessor, processor)
def get_video_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
**kwargs: Any,
):
"""Load a video processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers.image_processing_utils import BaseImageProcessor
processor = get_processor(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs,
)
return cast(BaseImageProcessor, processor.video_processor)
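A hedged sketch of how the relocated helper forwards overrides to HuggingFace (the num_crops kwarg is illustrative and model-dependent):

from vllm.transformers_utils.processor import get_image_processor

# Extra kwargs are forwarded verbatim to AutoImageProcessor.from_pretrained,
# so user-provided overrides apply when the HF processor is constructed.
image_processor = get_image_processor(
    "microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    num_crops=16,
)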

View File

@ -4,6 +4,7 @@ import contextlib
import datetime
import enum
import gc
import inspect
import os
import random
import socket
@ -1237,6 +1238,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return await task(*args, **kwargs)
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that cannot be kwarg
expanded to overwrite one or more keyword-only args. This is used in a
few places to handle custom processor overrides for multimodal models,
e.g., for profiling when processor options provided by the user
may affect the number of mm tokens per instance.
Args:
callable: Callable which takes 0 or more keyword only arguments.
overrides: Potential overrides to be used when invoking the callable.
Returns:
Dictionary containing the kwargs that may be used to overwrite one
or more keyword-only arguments when invoking the callable.
"""
if not overrides:
return {}
allowed_override_names = [
name for name, param in inspect.signature(callable).parameters.items()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
# Drop any mm_processor_kwargs provided by the user that are
# not kwarg names accepted by the provided callable.
filtered_overrides = {
kwarg_name: val
for kwarg_name, val in overrides.items()
if kwarg_name in allowed_override_names
}
# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
return filtered_overrides
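A small self-contained sketch of the filtering behavior this helper implements (callable and override names hypothetical):

def fake_processor(ctx, inputs, *, num_crops: int = 4):
    return num_crops

# "num_crops" is keyword-only and is kept; "ctx" (positional) and the unknown
# "does_not_exist" are dropped, with a warning logged for the dropped keys.
overrides = {"num_crops": 16, "ctx": "bad", "does_not_exist": 1}
assert get_allowed_kwarg_only_overrides(
    fake_processor, overrides=overrides) == {"num_crops": 16}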
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.