[Core][Frontend] Support Passing Multimodal Processor Kwargs (#8657)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Parent: d23679eb99
Commit: 9b8c8ba119
@@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected):
def test_bad_nullable_kvs(arg):
    with pytest.raises(ArgumentTypeError):
        nullable_kvs(arg)


@pytest.mark.parametrize(("arg", "expected"), [
    (None, None),
    ("{}", {}),
    ('{"num_crops": 4}', {
        "num_crops": 4
    }),
    ('{"foo": {"bar": "baz"}}', {
        "foo": {
            "bar": "baz"
        }
    }),
])
def test_mm_processor_kwargs_prompt_parser(arg, expected):
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    if arg is None:
        args = parser.parse_args([])
    else:
        args = parser.parse_args(["--mm-processor-kwargs", arg])
    assert args.mm_processor_kwargs == expected
@@ -5,14 +5,13 @@ import pytest
import torch
from PIL.Image import Image

from vllm.config import ModelConfig
from vllm.inputs import InputContext, LLMInputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size

from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput,
                          VllmRunner, _ImageAssets)
from ...utils import check_logprobs_close
from ...utils import build_model_context, check_logprobs_close

text_only_models = [
    "Qwen/Qwen-7B-Chat"  # Has no visual component
@@ -42,32 +41,6 @@ VIS_ENC_DIM = 4096
IMG_SIZE = 448


def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False):
    """Creates an InputContext for a given model.

    Args:
        model_name: Name of the model being considered.
        tokenizer_name: Name of the tokenizer being considered.
        trust_remote_code: Whether or not to allow loading remote code.

    Returns:
        InputContext for the model being considered.
    """
    if tokenizer_name is None:
        tokenizer_name = model_name
    model_config = ModelConfig(
        model_name,
        tokenizer_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="float32",
        seed=0,
    )
    return InputContext(model_config)


@pytest.fixture()
def input_mapper_for_qwen():
    # Lazy import to avoid initializing CUDA during test collection
@@ -1,6 +1,8 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs

TokensText = Tuple[List[int], str]
@@ -240,3 +242,36 @@ def check_logprobs_close(
            warnings.simplefilter("always")

            warnings.warn(fail_msg, stacklevel=2)


def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False,
                        mm_processor_kwargs: Optional[Dict] = None,
                        limit_mm_per_prompt: Optional[Dict] = None):
    """Creates an InputContext for a given model.

    Args:
        model_name: Name of the model being considered.
        tokenizer_name: Name of the tokenizer being considered.
        trust_remote_code: Whether or not to allow loading remote code.
        mm_processor_kwargs: Optional processor kwargs to be leveraged
            in the input processor, mapper, dummy data creation, etc.
        limit_mm_per_prompt: Multimodal limits.

    Returns:
        InputContext for the model being considered.
    """
    if tokenizer_name is None:
        tokenizer_name = model_name
    model_config = ModelConfig(
        model_name,
        tokenizer_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="float32",
        seed=0,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    return InputContext(model_config)
tests/multimodal/test_processor_kwargs.py (new file, 339 lines)
@@ -0,0 +1,339 @@
from array import array
from typing import Mapping
from unittest.mock import patch

import pytest
import torch

from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

from ..models.utils import build_model_context

# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that we can pass an override value through the
# processor kwargs and receive the intended result for things like sequence
# length etc.
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
    """Patches the internal model input processor with an override callable."""

    def custom_processor(ctx: InputContext,
                         llm_inputs: LLMInputs,
                         *,
                         num_crops=DEFAULT_NUM_CROPS):
        # For testing purposes, we don't worry about the llm inputs / return
        # type validation, and just return the value of the kwarg that we
        # clobber.
        return num_crops

    with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
               return_value=custom_processor):
        yield


@pytest.fixture
def use_dummy_data_mock():
    """Patches the internal dummy data factory with an override callable."""

    def custom_dummy_data_factory(self,
                                  ctx: InputContext,
                                  seq_len: int,
                                  mm_counts: Mapping[str, int],
                                  *,
                                  num_crops=DEFAULT_NUM_CROPS):
        seq_data = SequenceData(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
        return seq_data, None

    with patch(
            "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
            custom_dummy_data_factory):
        yield


# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
    from vllm.model_executor.models.phi3v import Phi3VForCausalLM

    return Phi3VForCausalLM


# Lambdas whose signatures match the max token calc & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}


### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
    proc_outputs = processor(inputs=proc_inputs)
    assert proc_inputs is proc_outputs
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()
    # If we have a value for num_crops, pass the override value and make
    # sure we get that value as a return value from our mock processor,
    # otherwise fall back to the default value
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    processor = dummy_registry.create_input_processor(ctx.model_config)

    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    assert num_crops_val == expected_num_crops
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword-only
        {
            "ctx": "something bad"
        }
    ])
def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    assert num_crops_val == DEFAULT_NUM_CROPS
### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
    """Ensure dummy data factories can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    seq_data, _ = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    assert len(seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword-only
        {
            "ctx": "something bad"
        }
    ])
def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
                                             mm_processor_kwargs):
    """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    seq_data, _ = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry)
    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
    """Ensure max token calcs can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert expected_seq_count == max_multimodal_tokens
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword-only
        {
            "ctx": "something bad"
        }
    ])
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
    """Ensure that max token calcs filter out invalid mm_processor_kwargs"""
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # Similar to before, but since these kwargs get filtered,
    # we always get our default value back.
    with patch.object(
            mm_registry._get_plugin("image"),
            "_max_mm_tokens",
            {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config)

    assert max_multimodal_tokens == DEFAULT_NUM_CROPS
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
    """Ensure that the mapper processor kwargs can fall back to HF models."""
    # NOTE - we don't validate bad inputs for the default mapper, because it
    # goes through the automodel interface in transformers, so we can't easily
    # inspect what kwargs are or are not allowed.
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs={"num_crops": num_crops},
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
    """Ensure custom mappers can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
            mm_registry._get_plugin("image"),
            "_default_input_mapper",
            {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {
            "does_not_exist": 100
        },
        # Part of the signature, not keyword-only
        {
            "ctx": "something bad"
        }
    ])
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides; since these kwargs get filtered, the mapper should
    # fall back to the default num_crops value.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
            mm_registry._get_plugin("image"),
            "_default_input_mapper",
            {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
@@ -122,6 +122,8 @@ class ModelConfig:
            can not be gathered from the vllm arguments.
        config_format: The config format which shall be loaded.
            Defaults to 'auto' which defaults to 'hf'.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor.
    """

    def __init__(self,
@@ -150,7 +152,8 @@ class ModelConfig:
                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
                 use_async_output_proc: bool = True,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 config_format: ConfigFormat = ConfigFormat.AUTO) -> None:
                 config_format: ConfigFormat = ConfigFormat.AUTO,
                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
@@ -184,6 +187,7 @@ class ModelConfig:
            self.model, revision)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
        self.use_async_output_proc = use_async_output_proc
        self.mm_processor_kwargs = mm_processor_kwargs

        # Set enforce_eager to False if the value is unset.
        if self.enforce_eager is None:
@@ -175,6 +175,7 @@ class EngineArgs:
    collect_detailed_traces: Optional[str] = None
    disable_async_output_proc: bool = False
    override_neuron_config: Optional[Dict[str, Any]] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.tokenizer is None:
@@ -513,6 +514,12 @@ class EngineArgs:
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
        parser.add_argument(
            '--mm-processor-kwargs',
            default=None,
            type=json.loads,
            help=('Overrides for the multimodal input mapping/processing, '
                  'e.g., image processor. For example: {"num_crops": 4}.'))

        # LoRA related configs
        parser.add_argument('--enable-lora',
@@ -822,6 +829,7 @@ class EngineArgs:
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )

    def create_load_config(self) -> LoadConfig:
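For reference, the new flag takes a JSON object on the command line. A minimal usage sketch against the OpenAI-compatible server entrypoint (the model and num_crops value are illustrative, not taken from this diff):

    python -m vllm.entrypoints.openai.api_server \
        --model microsoft/Phi-3.5-vision-instruct \
        --trust-remote-code \
        --limit-mm-per-prompt image=1 \
        --mm-processor-kwargs '{"num_crops": 16}'

Because the argument is parsed with json.loads, malformed JSON fails at argument-parsing time rather than at request time.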
@@ -235,7 +235,7 @@ class LLMEngine:
            "decoding_config=%r, observability_config=%r, "
            "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
            "use_async_output_proc=%s)",
            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
            VLLM_VERSION,
            model_config.model,
            speculative_config,
@@ -268,6 +268,7 @@ class LLMEngine:
            scheduler_config.num_scheduler_steps,
            cache_config.enable_prefix_caching,
            model_config.use_async_output_proc,
            model_config.mm_processor_kwargs,
        )
        # TODO(woosuk): Print more configs in debug mode.
        from vllm.plugins import load_general_plugins
@@ -134,6 +134,7 @@ class LLM:
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        '''
@@ -174,6 +175,7 @@ class LLM:
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
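At the Python API level, the same overrides can be passed directly to the LLM constructor; a minimal sketch (model and values are illustrative):

    from vllm import LLM

    # Forwarded through EngineArgs/ModelConfig into the multimodal input
    # processor, mapper, max-token, and dummy-data profiling paths.
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 1},
        mm_processor_kwargs={"num_crops": 16},
    )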
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides

from .data import LLMInputs

@@ -68,12 +69,17 @@ class DummyDataFactory(Protocol):
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Create dummy data to be inputted into the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.

        The :code:`mm_processor_kwargs` are overrides provided at
        initialization time to values in the config whose values
        may affect the number of tokens per instance.
        """
        ...

@@ -152,6 +158,10 @@ class InputRegistry:

        return wrapper

    def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
@@ -174,15 +184,15 @@ class InputRegistry:
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        dummy_factory = self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)
        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
        dummy_factory = self._get_dummy_data_factory(model_cls)

        seq_data, mm_data = dummy_factory(
            InputContext(model_config),
            seq_len,
            _MultiModalCounts(mm_counts),
        )
        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
            dummy_factory, overrides=model_config.mm_processor_kwargs)

        seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
                                          _MultiModalCounts(mm_counts),
                                          **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = seq_data.prompt_token_ids
@@ -229,6 +239,10 @@ class InputRegistry:

        return wrapper

    def _get_model_input_processor(self, model_cls: Type[nn.Module]):
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def process_input(self, model_config: "ModelConfig",
                      inputs: LLMInputs) -> LLMInputs:
        """
@@ -243,15 +257,17 @@ class InputRegistry:
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        processor = self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)
        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
            processor, overrides=model_config.mm_processor_kwargs)

        return processor(InputContext(model_config), inputs)
        return processor(InputContext(model_config), inputs,
                         **mm_processor_kwargs)

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        Create an input processor (see :meth:`_process_input`) for a
        specific model.
        """
        return functools.partial(self.process_input, model_config)
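To receive these overrides, a model's registered input processor simply declares them as keyword-only parameters; anything not keyword-only, or not in the signature, is filtered out by get_allowed_kwarg_only_overrides before the call. A minimal sketch mirroring the mock processor used in the tests above (the function name and default are illustrative):

    from vllm.inputs import InputContext, LLMInputs

    def my_input_processor(ctx: InputContext,
                           llm_inputs: LLMInputs,
                           *,
                           num_crops: int = 4) -> LLMInputs:
        # num_crops is filled from model_config.mm_processor_kwargs when the
        # user supplies it; otherwise the declared default applies.
        return llm_inputs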
@@ -14,7 +14,8 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, is_list_of, json_map_leaves
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
                        json_map_leaves)

logger = init_logger(__name__)

@@ -256,11 +257,20 @@ class MultiModalPlugin(ABC):
        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        # Only get processor kwargs at mapping time if we are not using the
        # default input mapper; no overrides are used on the default here
        # because they should be passed to the huggingface resource at
        # initialization time.
        if mapper is not None and mapper != self._default_input_mapper:
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                mapper, overrides=model_config.mm_processor_kwargs)
        else:
            mm_processor_kwargs = {}

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
@@ -333,7 +343,10 @@ class MultiModalPlugin(ABC):
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)
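The same keyword-only convention applies to registered max-token callables and custom input mappers; positional parameters such as ctx and data are never overridden. A small sketch mirroring the lambdas used in the tests above (names are illustrative):

    import torch
    from vllm.inputs import InputContext

    def max_image_tokens(ctx: InputContext, *, num_crops: int = 4) -> int:
        # Eligible for override via mm_processor_kwargs
        return num_crops

    def image_mapper(ctx: InputContext, data, *, num_crops: int = 4):
        # Shape mirrors the Phi3v-style [batch, num_crops + 1, 3, 336, 336]
        return {"pixel_values": torch.zeros(1, num_crops + 1, 3, 336, 336)}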
@@ -6,7 +6,7 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of

from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@@ -23,9 +23,14 @@ class ImagePlugin(MultiModalPlugin):
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
                               else model_config.mm_processor_kwargs)
        # We don't explicitly check kwarg overrides to the HF class
        # since the automodel just takes kwargs, so we can't inspect it
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)
            trust_remote_code=model_config.trust_remote_code,
            **mm_processor_kwargs)

    def _default_input_mapper(
        self,
@@ -37,6 +42,7 @@ class ImagePlugin(MultiModalPlugin):
        # PIL image
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)

            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
@@ -138,6 +138,15 @@ class MultiModalRegistry:
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not do so, for simplicity of both use &
        # implementation.
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
@@ -6,7 +6,7 @@ import numpy as np
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_video_processor
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of

@@ -37,9 +37,14 @@ class VideoPlugin(ImagePlugin):
        return "video"

    def _get_hf_video_processor(self, model_config: ModelConfig):
        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
                               else model_config.mm_processor_kwargs)
        # We don't explicitly check kwarg overrides to the HF class
        # since the automodel just takes kwargs, so we can't inspect it
        return cached_get_video_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)
            trust_remote_code=model_config.trust_remote_code,
            **mm_processor_kwargs)

    def _default_input_mapper(
        self,
@@ -1,64 +0,0 @@
from typing import cast


def get_video_processor(
    processor_name: str,
    trust_remote_code: bool = False,
):
    """
    Gets a processor for the given model name via HuggingFace.
    """
    from transformers import AutoProcessor

    try:
        processor = AutoProcessor.from_pretrained(processor_name)
        video_processor = processor.video_processor

    except ValueError as e:
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e
    return video_processor


def get_image_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    **kwargs,
):
    """Gets an image processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return cast(BaseImageProcessor, processor)
@@ -1,13 +1,13 @@
from typing import cast
from typing import Any, cast


def get_processor(
    processor_name: str,
    *args,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs,
    **kwargs: Any,
):
    """Gets a processor for the given model name via HuggingFace."""
    """Load a processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor
@@ -35,3 +35,60 @@ def get_processor(
            raise e

    return cast(ProcessorMixin, processor)


def get_image_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return cast(BaseImageProcessor, processor)


def get_video_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load a video processor for the given model name via HuggingFace."""
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers.image_processing_utils import BaseImageProcessor

    processor = get_processor(
        processor_name,
        *args,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )

    return cast(BaseImageProcessor, processor.video_processor)
@@ -4,6 +4,7 @@ import contextlib
import datetime
import enum
import gc
import inspect
import os
import random
import socket
@@ -1237,6 +1238,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
    return await task(*args, **kwargs)


def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Given a callable which has one or more keyword-only params and a dict
    mapping param names to values, drop values that cannot be kwarg-expanded
    to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
        overrides: Potential overrides to be used when invoking the callable.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    allowed_override_names = [
        name for name, param in inspect.signature(callable).parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    # Drop any mm_processor_kwargs provided by the user that are
    # not kwarg names accepted by the provided input processor.
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if kwarg_name in allowed_override_names
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        logger.warning(
            "The following intended overrides are not keyword-only args "
            "and will be dropped: %s", dropped_keys)

    return filtered_overrides
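A quick sketch of the filtering behavior (the resize function here is hypothetical):

    from vllm.utils import get_allowed_kwarg_only_overrides

    def resize(data, *, num_crops: int = 4):
        ...

    # "num_crops" is keyword-only, so it survives; "data" is positional and
    # "typo" is not in the signature, so both are dropped with a warning.
    kwargs = get_allowed_kwarg_only_overrides(
        resize, overrides={"num_crops": 16, "data": None, "typo": 1})
    assert kwargs == {"num_crops": 16}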


# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.