[Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
commit 9b8c8ba119 (parent d23679eb99)
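
What the change enables, end to end: a dict of keyword arguments can now be passed through the engine and forwarded to the HuggingFace processor for multimodal data. A minimal usage sketch (the model ID and the num_crops kwarg are taken from the tests below; treat the exact values as illustrative):

    from vllm import LLM

    # Offline entrypoint: the dict is stored on ModelConfig.mm_processor_kwargs
    # and forwarded to the HF image processor, input processor/mapper, and
    # dummy data factory (invalid keys are filtered with a warning).
    llm = LLM(model="microsoft/Phi-3.5-vision-instruct",
              trust_remote_code=True,
              mm_processor_kwargs={"num_crops": 16})

    # Equivalent CLI flag (parsed with json.loads):
    #   --mm-processor-kwargs '{"num_crops": 4}'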
@@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected):
 def test_bad_nullable_kvs(arg):
     with pytest.raises(ArgumentTypeError):
         nullable_kvs(arg)
+
+
+@pytest.mark.parametrize(("arg", "expected"), [
+    (None, None),
+    ("{}", {}),
+    ('{"num_crops": 4}', {
+        "num_crops": 4
+    }),
+    ('{"foo": {"bar": "baz"}}', {
+        "foo": {
+            "bar": "baz"
+        }
+    }),
+])
+def test_mm_processor_kwargs_prompt_parser(arg, expected):
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    if arg is None:
+        args = parser.parse_args([])
+    else:
+        args = parser.parse_args(["--mm-processor-kwargs", arg])
+    assert args.mm_processor_kwargs == expected
@@ -5,14 +5,13 @@ import pytest
 import torch
 from PIL.Image import Image
 
-from vllm.config import ModelConfig
 from vllm.inputs import InputContext, LLMInputs
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
                           VllmRunner, _ImageAssets)
-from ...utils import check_logprobs_close
+from ...utils import build_model_context, check_logprobs_close
 
 text_only_models = [
     "Qwen/Qwen-7B-Chat"  # Has no visual component
@@ -42,32 +41,6 @@ VIS_ENC_DIM = 4096
 IMG_SIZE = 448
 
 
-def build_model_context(model_name: str,
-                        tokenizer_name: Optional[str] = None,
-                        trust_remote_code: bool = False):
-    """Creates an InputContext for a given model.
-
-    Args:
-        model_name: Name of the model being considered.
-        tokenizer_name: Name of the tokenizer being considered.
-        trust_remote_code: Whether or not to allow loading remote code.
-
-    Returns:
-        InputContext for the model being considered.
-    """
-    if tokenizer_name is None:
-        tokenizer_name = model_name
-    model_config = ModelConfig(
-        model_name,
-        tokenizer_name,
-        tokenizer_mode="auto",
-        trust_remote_code=trust_remote_code,
-        dtype="float32",
-        seed=0,
-    )
-    return InputContext(model_config)
-
-
 @pytest.fixture()
 def input_mapper_for_qwen():
     # Lazy import to avoid initializing CUDA during test collection
@@ -1,6 +1,8 @@
 import warnings
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
+from vllm.config import ModelConfig
+from vllm.inputs import InputContext
 from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 
 TokensText = Tuple[List[int], str]
@@ -240,3 +242,36 @@ def check_logprobs_close(
            warnings.simplefilter("always")

            warnings.warn(fail_msg, stacklevel=2)
+
+
+def build_model_context(model_name: str,
+                        tokenizer_name: Optional[str] = None,
+                        trust_remote_code: bool = False,
+                        mm_processor_kwargs: Optional[Dict] = None,
+                        limit_mm_per_prompt: Optional[Dict] = None):
+    """Creates an InputContext for a given model.
+
+    Args:
+        model_name: Name of the model being considered.
+        tokenizer_name: Name of the tokenizer being considered.
+        trust_remote_code: Whether or not to allow loading remote code.
+        mm_processor_kwargs: optional processor kwargs to be leveraged
+            in the input processor, mapper, dummy data creation, etc.
+        limit_mm_per_prompt: Multimodal limits.
+
+    Returns:
+        InputContext for the model being considered.
+    """
+    if tokenizer_name is None:
+        tokenizer_name = model_name
+    model_config = ModelConfig(
+        model_name,
+        tokenizer_name,
+        tokenizer_mode="auto",
+        trust_remote_code=trust_remote_code,
+        dtype="float32",
+        seed=0,
+        mm_processor_kwargs=mm_processor_kwargs,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    return InputContext(model_config)
tests/multimodal/test_processor_kwargs.py (new file, 339 lines)
@@ -0,0 +1,339 @@
+from array import array
+from typing import Mapping
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.inputs import InputContext, LLMInputs
+from vllm.inputs.registry import InputRegistry
+from vllm.multimodal import MultiModalRegistry
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+
+from ..models.utils import build_model_context
+
+# Used for fast tests where the model doesn't matter
+DUMMY_MODEL_ID = "facebook/opt-125m"
+# Used for tests that need a multimodal model
+MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
+
+# For mm_processor_kwargs - we test overrides by defining mocks for each place
+# it is used, and ensuring that we can pass processor kwargs an override value
+# to receive the intended result for things like sequence length etc.
+DEFAULT_NUM_CROPS = 4
+NUM_CROPS_OVERRIDE = 16
+
+
+# Mocks for all of the places that we use the mm_processor_kwargs
+# to override values in different callables
+@pytest.fixture
+def use_processor_mock():
+    """Patches the internal model input processor with an override callable."""
+
+    def custom_processor(ctx: InputContext,
+                         llm_inputs: LLMInputs,
+                         *,
+                         num_crops=DEFAULT_NUM_CROPS):
+        # For testing purposes, we don't worry about the llm inputs / return
+        # type validation, and just return the value of the kwarg that we
+        # clobber.
+        return num_crops
+
+    with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
+               return_value=custom_processor):
+        yield
+
+
+@pytest.fixture
+def use_dummy_data_mock():
+    """Patches the internal model input processor with an override callable."""
+
+    def custom_dummy_data_factory(self,
+                                  ctx: InputContext,
+                                  seq_len: int,
+                                  mm_counts: Mapping[str, int],
+                                  *,
+                                  num_crops=DEFAULT_NUM_CROPS):
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
+        return seq_data, None
+
+    with patch(
+            "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
+            custom_dummy_data_factory):
+        yield
+
+
+# Lazy import to avoid CUDA reinitialization error
+def mm_model_cls():
+    from vllm.model_executor.models.phi3v import Phi3VForCausalLM
+
+    return Phi3VForCausalLM
+
+
+# lambda whose signature matches max token calcs extra & mapper + extra kwargs
+get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
+custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
+    "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+}
+
+
+### Test for default processor logic & mm_processor_kwargs wrapping
+def test_default_processor_is_a_noop():
+    """Ensure that by default, there is no processor override."""
+    dummy_registry = InputRegistry()
+    ctx = build_model_context(DUMMY_MODEL_ID)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
+    proc_outputs = processor(inputs=proc_inputs)
+    assert proc_inputs is proc_outputs
+
+
+@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
+def test_processor_default_kwargs(use_processor_mock, num_crops):
+    """Ensure input processors can use processor kwargs."""
+    dummy_registry = InputRegistry()
+    # If we have a value for num_crops, pass the override value and make
+    # sure we get that value as a return-value from our mock processor,
+    # otherwise fall back to the default value
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
+    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              mm_processor_kwargs=mm_processor_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+
+    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    assert num_crops_val == expected_num_crops
+
+
+@pytest.mark.parametrize(
+    "mm_processor_kwargs",
+    [
+        # Not part of the signature
+        {
+            "does_not_exist": 100
+        },
+        # Part of the signature, not keyword only
+        {
+            "ctx": "something bad"
+        }
+    ])
+def test_processor_with_sad_kwarg_overrides(use_processor_mock,
+                                            mm_processor_kwargs):
+    """Ensure that input processors filter out invalid mm_processor_kwargs"""
+    dummy_registry = InputRegistry()
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              mm_processor_kwargs=mm_processor_kwargs)
+
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    assert num_crops_val == DEFAULT_NUM_CROPS
+
+
+### Test overrides for the dummy data
+@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
+def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
+    """Ensure dummy data factories can use processor kwargs."""
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
+    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    dummy_registry = InputRegistry()
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              mm_processor_kwargs=mm_processor_kwargs)
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    # NOTE: seq_len is thrown away here since this will leverage the
+    # default dummy data factory that we have patched in, whose seq
+    # len is solely dependent on the value of the mm_processor_kwargs.
+    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
+    assert len(seq_data.prompt_token_ids) == expected_seq_count
+
+
+@pytest.mark.parametrize(
+    "mm_processor_kwargs",
+    [
+        # Not part of the signature
+        {
+            "does_not_exist": 100
+        },
+        # Part of the signature, not keyword only
+        {
+            "ctx": "something bad"
+        }
+    ])
+def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
+                                             mm_processor_kwargs):
+    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
+    dummy_registry = InputRegistry()
+    ctx = build_model_context(DUMMY_MODEL_ID,
+                              mm_processor_kwargs=mm_processor_kwargs)
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    # NOTE: seq_len is thrown away here since this will leverage the
+    # default dummy data factory that we have patched in, whose seq
+    # len is solely dependent on the value of the mm_processor_kwargs.
+    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
+    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
+
+
+### Test overrides for the max token count per multimodal instance
+@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
+def test_max_tokens_kwarg_overrides(num_crops):
+    """Ensure max token calcs can use processor kwargs."""
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
+    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              mm_processor_kwargs=mm_processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echoes
+    # our num_crops value back from the mm_processor_kwargs.
+    with patch.object(
+            mm_registry._get_plugin("image"),
+            "_max_mm_tokens",
+            {mm_model_cls(): get_num_crops},
+    ):
+        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
+            ctx.model_config)
+
+    assert expected_seq_count == max_multimodal_tokens
+
+
+@pytest.mark.parametrize(
+    "mm_processor_kwargs",
+    [
+        # Not part of the signature
+        {
+            "does_not_exist": 100
+        },
+        # Part of the signature, not keyword only
+        {
+            "ctx": "something bad"
+        }
+    ])
+def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
+    """Ensure that max token calcs filter out invalid mm_processor_kwargs"""
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              mm_processor_kwargs=mm_processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    # Similar to before, but since these kwargs get filtered,
+    # we always get our default value back.
+    with patch.object(
+            mm_registry._get_plugin("image"),
+            "_max_mm_tokens",
+            {mm_model_cls(): get_num_crops},
+    ):
+        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
+            ctx.model_config)
+
+    assert max_multimodal_tokens == DEFAULT_NUM_CROPS
+
+
+### Test overrides for the mapper
+@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
+def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
+    """Ensure that the mapper processor kwargs can fall back to HF models."""
+    # NOTE - we don't validate bad inputs for the default mapper, because it's
+    # through the automodel interface in transformers, so we can't easily
+    # inspect what kwargs are or are not allowed.
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              mm_processor_kwargs={"num_crops": num_crops},
+                              limit_mm_per_prompt={"image": 1})
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image
+    mm_inputs = {"image": image}
+
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
+    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
+
+
+@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
+def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
+    """Ensure custom mappers can use processor kwargs."""
+    mm_processor_kwargs = None if num_crops is None else {
+        "num_crops": num_crops
+    }
+    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              mm_processor_kwargs=mm_processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echoes
+    # our num_crops value back from the mm_processor_kwargs.
+    image = image_assets[0].pil_image
+    mm_inputs = {"image": image}
+
+    with patch.object(
+            mm_registry._get_plugin("image"),
+            "_default_input_mapper",
+            {mm_model_cls(): custom_mapper},
+    ):
+        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+
+    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
+
+
+@pytest.mark.parametrize(
+    "mm_processor_kwargs",
+    [
+        # Not part of the signature
+        {
+            "does_not_exist": 100
+        },
+        # Part of the signature, not keyword only
+        {
+            "ctx": "something bad"
+        }
+    ])
+def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
+                                                mm_processor_kwargs):
+    """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
+    ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              trust_remote_code=True,
+                              mm_processor_kwargs=mm_processor_kwargs,
+                              limit_mm_per_prompt={"image": 1})
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echoes
+    # our num_crops value back from the mm_processor_kwargs.
+    image = image_assets[0].pil_image
+    mm_inputs = {"image": image}
+
+    with patch.object(
+            mm_registry._get_plugin("image"),
+            "_default_input_mapper",
+            {mm_model_cls(): custom_mapper},
+    ):
+        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+
+    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
@@ -122,6 +122,8 @@ class ModelConfig:
         can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
+        mm_processor_kwargs: Arguments to be forwarded to the model's processor
+            for multi-modal data, e.g., image processor.
     """
 
     def __init__(self,
@@ -150,7 +152,8 @@ class ModelConfig:
                  limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                  use_async_output_proc: bool = True,
                  override_neuron_config: Optional[Dict[str, Any]] = None,
-                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
+                 config_format: ConfigFormat = ConfigFormat.AUTO,
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -184,6 +187,7 @@ class ModelConfig:
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.use_async_output_proc = use_async_output_proc
+        self.mm_processor_kwargs = mm_processor_kwargs
 
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -175,6 +175,7 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     override_neuron_config: Optional[Dict[str, Any]] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
     def __post_init__(self):
        if self.tokenizer is None:
@@ -513,6 +514,12 @@ class EngineArgs:
                   'e.g.: `image=16,video=2` allows a maximum of 16 '
                   'images and 2 videos per prompt. Defaults to 1 for '
                   'each modality.'))
+        parser.add_argument(
+            '--mm-processor-kwargs',
+            default=None,
+            type=json.loads,
+            help=('Overrides for the multimodal input mapping/processing,'
+                  'e.g., image processor. For example: {"num_crops": 4}.'))
 
         # LoRA related configs
         parser.add_argument('--enable-lora',
@@ -822,6 +829,7 @@ class EngineArgs:
             use_async_output_proc=not self.disable_async_output_proc,
             override_neuron_config=self.override_neuron_config,
             config_format=self.config_format,
+            mm_processor_kwargs=self.mm_processor_kwargs,
         )
 
     def create_load_config(self) -> LoadConfig:
@@ -235,7 +235,7 @@ class LLMEngine:
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
             "num_scheduler_steps=%d, enable_prefix_caching=%s, "
-            "use_async_output_proc=%s)",
+            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -268,6 +268,7 @@ class LLMEngine:
             scheduler_config.num_scheduler_steps,
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
+            model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
         from vllm.plugins import load_general_plugins
@@ -134,6 +134,7 @@ class LLM:
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -174,6 +175,7 @@ class LLM:
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
+            mm_processor_kwargs=mm_processor_kwargs,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
 from typing_extensions import TypeVar
 
 from vllm.logger import init_logger
+from vllm.utils import get_allowed_kwarg_only_overrides
 
 from .data import LLMInputs
 
@@ -68,12 +69,17 @@ class DummyDataFactory(Protocol):
         ctx: InputContext,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        **mm_processor_kwargs: Any,
     ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
         """
         Create dummy data to be inputted into the model.
 
         Note:
             :data:`InputProcessor` is not applied to the dummy data.
+
+            The :code:`mm_processor_kwargs` are overrides provided at
+            initialization time to values in the config whose values
+            may affect the number of tokens per instance.
         """
         ...
 
@@ -152,6 +158,10 @@ class InputRegistry:
 
         return wrapper
 
+    def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
+        return self._dummy_factories_by_model_type \
+            .get(model_cls, self._default_dummy_data_factory)
+
     def dummy_data_for_profiling(
         self,
         model_config: "ModelConfig",
@@ -174,15 +184,15 @@ class InputRegistry:
         from vllm.model_executor.model_loader import get_model_architecture
 
         model_cls, _ = get_model_architecture(model_config)
-        dummy_factory = self._dummy_factories_by_model_type \
-            .get(model_cls, self._default_dummy_data_factory)
-        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
+        dummy_factory = self._get_dummy_data_factory(model_cls)
 
-        seq_data, mm_data = dummy_factory(
-            InputContext(model_config),
-            seq_len,
-            _MultiModalCounts(mm_counts),
-        )
+        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
+        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
+            dummy_factory, overrides=model_config.mm_processor_kwargs)
+
+        seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
+                                          _MultiModalCounts(mm_counts),
+                                          **mm_processor_kwargs)
 
         # Having more tokens is over-conservative but otherwise fine
         num_tokens = seq_data.prompt_token_ids
@@ -229,6 +239,10 @@ class InputRegistry:
 
         return wrapper
 
+    def _get_model_input_processor(self, model_cls: Type[nn.Module]):
+        return self._input_processors_by_model_type \
+            .get(model_cls, self._default_input_processor)
+
     def process_input(self, model_config: "ModelConfig",
                       inputs: LLMInputs) -> LLMInputs:
         """
@@ -243,15 +257,17 @@ class InputRegistry:
         from vllm.model_executor.model_loader import get_model_architecture
 
         model_cls, _ = get_model_architecture(model_config)
+        processor = self._get_model_input_processor(model_cls)
 
-        processor = self._input_processors_by_model_type \
-            .get(model_cls, self._default_input_processor)
+        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
+            processor, overrides=model_config.mm_processor_kwargs)
 
-        return processor(InputContext(model_config), inputs)
+        return processor(InputContext(model_config), inputs,
+                         **mm_processor_kwargs)
 
     def create_input_processor(self, model_config: "ModelConfig"):
         """
-        Create an input processor (see :meth:`process_input`) for a
+        Create an input processor (see :meth:`_process_input`) for a
         specific model.
         """
         return functools.partial(self.process_input, model_config)
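
The registry changes above define the contract for model implementers: a registered input processor (or dummy data factory) declares keyword-only parameters, and any matching entries from ModelConfig.mm_processor_kwargs are expanded into the call, with non-matching keys filtered out. A hypothetical processor illustrating this (the name and the num_crops kwarg are made up for illustration, mirroring the test mocks):

    def input_processor_for_my_model(ctx: InputContext,
                                     llm_inputs: LLMInputs,
                                     *,
                                     num_crops: int = 4) -> LLMInputs:
        # Invoked as processor(InputContext(model_config), inputs,
        # **filtered_kwargs), so num_crops here reflects any user override.
        return llm_inputs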
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
+                        json_map_leaves)
 
 logger = init_logger(__name__)
 
@@ -256,11 +257,20 @@ class MultiModalPlugin(ABC):
         model_cls, _ = get_model_architecture(model_config)
 
         mapper = self._input_mappers.get(model_cls)
+        # Only get processor kwargs at mapping time if we are not using the
+        # input mapper; no overrides are used on the default here because they
+        # should be passed to the huggingface resource at initialization time.
+        if mapper is not None and mapper != self._default_input_mapper:
+            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
+                mapper, overrides=model_config.mm_processor_kwargs)
+        else:
+            mm_processor_kwargs = {}
+
         if mapper is None:
             raise KeyError(f"No input mapper in {self} is registered for "
                            f"model class {model_cls.__name__}.")
 
-        return mapper(InputContext(model_config), data)
+        return mapper(InputContext(model_config), data, **mm_processor_kwargs)
 
     @abstractmethod
     def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
@@ -333,7 +343,10 @@ class MultiModalPlugin(ABC):
                 f"for model class {model_cls.__name__} in {self}.")
 
         if callable(max_mm_tokens):
-            max_mm_tokens = max_mm_tokens(InputContext(model_config))
+            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
+                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
+            max_mm_tokens = max_mm_tokens(InputContext(model_config),
+                                          **mm_processor_kwargs)
 
         self._validate_max_multimodal_tokens(max_mm_tokens)
 
@@ -6,7 +6,7 @@ from PIL import Image
 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
-from vllm.transformers_utils.image_processor import get_image_processor
+from vllm.transformers_utils.processor import get_image_processor
 from vllm.utils import is_list_of
 
 from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@@ -23,9 +23,14 @@ class ImagePlugin(MultiModalPlugin):
         return "image"
 
     def _get_hf_image_processor(self, model_config: ModelConfig):
+        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
+                               else model_config.mm_processor_kwargs)
+        # We don't explicitly check kwarg overrides to the HF class
+        # since the automodel just takes kwargs, so we can't inspect it
         return cached_get_image_processor(
             model_config.model,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            **mm_processor_kwargs)
 
     def _default_input_mapper(
         self,
@@ -37,6 +42,7 @@ class ImagePlugin(MultiModalPlugin):
         # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
+
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available "
                                    "to process the image object")
@@ -138,6 +138,15 @@ class MultiModalRegistry:
         """
         Create an input mapper (see :meth:`map_input`) for a specific model.
         """
+        # NOTE - we currently make the assumption that if a model has multiple
+        # supported modalities, they take the same kwargs. For the default,
+        # this could be an issue in the future if it falls back to two HF
+        # resources and we can't inspect the signature easily since it's
+        # getting initialized through the autoclass.
+        #
+        # If this is a problem in the future, we should revisit it, but since
+        # it potentially introduces a lot of complexity for a currently
+        # uncommon case, we do not for simplicity of both use & implementation
         return functools.partial(self.map_input, model_config)
 
     def register_max_multimodal_tokens(
@@ -6,7 +6,7 @@ import numpy as np
 from vllm.config import ModelConfig
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
-from vllm.transformers_utils.image_processor import get_video_processor
+from vllm.transformers_utils.processor import get_video_processor
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import is_list_of
 
@@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin):
         return "video"
 
     def _get_hf_video_processor(self, model_config: ModelConfig):
+        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
+                               else model_config.mm_processor_kwargs)
+        # We don't explicitly check kwarg overrides to the HF class
+        # since the automodel just takes kwargs, so we can't inspect it
         return cached_get_video_processor(
             model_config.model,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            **mm_processor_kwargs)
 
     def _default_input_mapper(
         self,
@@ -1,64 +0,0 @@
-from typing import cast
-
-
-def get_video_processor(
-    processor_name: str,
-    trust_remote_code: bool = False,
-):
-    """
-    Gets a processor for the given model name via HuggingFace.
-    """
-    from transformers import AutoProcessor
-
-    try:
-        processor = AutoProcessor.from_pretrained(processor_name)
-        video_processor = processor.video_processor
-
-    except ValueError as e:
-        if not trust_remote_code:
-            err_msg = (
-                "Failed to load the processor. If the processor is "
-                "a custom processor not yet available in the HuggingFace "
-                "transformers library, consider setting "
-                "`trust_remote_code=True` in LLM or using the "
-                "`--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-    return video_processor
-
-
-def get_image_processor(
-    processor_name: str,
-    *args,
-    trust_remote_code: bool = False,
-    **kwargs,
-):
-    """Gets an image processor for the given model name via HuggingFace."""
-    # don't put this import at the top level
-    # it will call torch.cuda.device_count()
-    from transformers import AutoImageProcessor
-    from transformers.image_processing_utils import BaseImageProcessor
-
-    try:
-        processor = AutoImageProcessor.from_pretrained(
-            processor_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the processor class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
-        if not trust_remote_code:
-            err_msg = (
-                "Failed to load the image processor. If the image processor is "
-                "a custom processor not yet available in the HuggingFace "
-                "transformers library, consider setting "
-                "`trust_remote_code=True` in LLM or using the "
-                "`--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-
-    return cast(BaseImageProcessor, processor)
@@ -1,13 +1,13 @@
-from typing import cast
+from typing import Any, cast
 
 
 def get_processor(
     processor_name: str,
-    *args,
+    *args: Any,
     trust_remote_code: bool = False,
-    **kwargs,
+    **kwargs: Any,
 ):
-    """Gets a processor for the given model name via HuggingFace."""
+    """Load a processor for the given model name via HuggingFace."""
     # don't put this import at the top level
     # it will call torch.cuda.device_count()
     from transformers import AutoProcessor
@@ -35,3 +35,60 @@ def get_processor(
             raise e
 
     return cast(ProcessorMixin, processor)
+
+
+def get_image_processor(
+    processor_name: str,
+    *args: Any,
+    trust_remote_code: bool = False,
+    **kwargs: Any,
+):
+    """Load an image processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoImageProcessor
+    from transformers.image_processing_utils import BaseImageProcessor
+
+    try:
+        processor = AutoImageProcessor.from_pretrained(
+            processor_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            **kwargs)
+    except ValueError as e:
+        # If the error pertains to the processor class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
+        if not trust_remote_code:
+            err_msg = (
+                "Failed to load the image processor. If the image processor is "
+                "a custom processor not yet available in the HuggingFace "
+                "transformers library, consider setting "
+                "`trust_remote_code=True` in LLM or using the "
+                "`--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
+
+    return cast(BaseImageProcessor, processor)
+
+
+def get_video_processor(
+    processor_name: str,
+    *args: Any,
+    trust_remote_code: bool = False,
+    **kwargs: Any,
+):
+    """Load a video processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers.image_processing_utils import BaseImageProcessor
+
+    processor = get_processor(
+        processor_name,
+        *args,
+        trust_remote_code=trust_remote_code,
+        **kwargs,
+    )
+
+    return cast(BaseImageProcessor, processor.video_processor)
@@ -4,6 +4,7 @@ import contextlib
 import datetime
 import enum
 import gc
+import inspect
 import os
 import random
 import socket
@@ -1237,6 +1238,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     return await task(*args, **kwargs)
 
 
+def get_allowed_kwarg_only_overrides(
+    callable: Callable[..., object],
+    overrides: Optional[Dict[str, Any]],
+) -> Dict[str, Any]:
+    """
+    Given a callable which has one or more keyword only params and a dict
+    mapping param names to values, drop values that cannot be kwarg
+    expanded to overwrite one or more keyword-only args. This is used in a
+    few places to handle custom processor overrides for multimodal models,
+    e.g., for profiling when processor options provided by the user
+    may affect the number of mm tokens per instance.
+
+    Args:
+        callable: Callable which takes 0 or more keyword only arguments.
+        overrides: Potential overrides to be used when invoking the callable.
+
+    Returns:
+        Dictionary containing the kwargs to be leveraged which may be used
+        to overwrite one or more keyword only arguments when invoking the
+        callable.
+    """
+    if not overrides:
+        return {}
+
+    allowed_override_names = [
+        name for name, param in inspect.signature(callable).parameters.items()
+        if param.kind == inspect.Parameter.KEYWORD_ONLY
+    ]
+
+    # Drop any mm_processor_kwargs provided by the user that are
+    # not kwarg names accepted by the provided input processor.
+    filtered_overrides = {
+        kwarg_name: val
+        for kwarg_name, val in overrides.items()
+        if kwarg_name in allowed_override_names
+    }
+
+    # If anything is dropped, log a warning
+    dropped_keys = overrides.keys() - filtered_overrides.keys()
+    if dropped_keys:
+        logger.warning(
+            "The following intended overrides are not keyword-only args "
+            "and will be dropped: %s", dropped_keys)
+
+    return filtered_overrides
+
+
 # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
 # In particular, the FakeScalarType is not supported for earlier versions of
 # PyTorch which breaks dynamo for any ops registered using ScalarType.
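
A quick sketch of the helper's filtering behavior, following the implementation above (the mapper and its kwargs are hypothetical):

    def mapper(ctx, data, *, num_crops=4):
        ...

    get_allowed_kwarg_only_overrides(
        mapper, overrides={"num_crops": 16, "ctx": "bad", "unknown": 1})
    # -> {"num_crops": 16}: "ctx" is not keyword-only and "unknown" is not in
    #    the signature, so both are dropped with a warning.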
|
Loading…
x
Reference in New Issue
Block a user