[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex Brooks 2024-10-08 08:12:56 -06:00 committed by GitHub
parent 8c746226c9
commit a3691b6b5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 440 additions and 118 deletions

View File

@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         max_num_seqs=2,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
     )
     stop_token_ids = None
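To illustrate the note added above, here is a minimal sketch of overriding the processor kwargs per request through the prompt dict accepted by `generate`. The model name, prompt template, and crop counts are illustrative only, and `image` is assumed to be a `PIL.Image` loaded elsewhere:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    # Initialization-time default used by every request.
    mm_processor_kwargs={"num_crops": 16},
)

# Inference-time override: this request is preprocessed with 4 crops instead.
outputs = llm.generate(
    {
        "prompt": "<|user|>\n<|image_1|>\nWhat is shown here?<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": image},
        "mm_processor_kwargs": {"num_crops": 4},
    },
    SamplingParams(max_tokens=64),
)
```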

View File

@ -74,11 +74,11 @@ def mm_model_cls():
# lambda whose signature matches max token calcs extra & mapper + extra kwargs # lambda whose signature matches max token calcs extra & mapper + extra kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
"num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
} }
### Test for default processor logic & mm_processor_kwargs wrapping ### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop(): def test_default_processor_is_a_noop():
"""Ensure that by default, there is no processor override.""" """Ensure that by default, there is no processor override."""
dummy_registry = InputRegistry() dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
     assert proc_inputs is proc_outputs


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_processor_default_kwargs(use_processor_mock, num_crops):
-    """Ensure input processors can use processor kwargs."""
-    dummy_registry = InputRegistry()
+def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
+    """Get the init / inference kwargs and expected num_crops for this test."""
     # If we have a value for num_crops, pass the override value and make
     # sure we get that value as a return-value from out mock processor,
     # otherwise fall back to the default value
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    init_kwargs = None if init_num_crops is None else {
+        "num_crops": init_num_crops
     }
-    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
-
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
-    assert num_crops_val == expected_num_crops
+    inference_kwargs = None if inference_num_crops is None else {
+        "num_crops": inference_num_crops
+    }
+    if inference_num_crops is not None:
+        expected_seq_count = inference_num_crops
+    elif init_num_crops is not None:
+        expected_seq_count = init_num_crops
+    else:
+        expected_seq_count = DEFAULT_NUM_CROPS
+    return init_kwargs, inference_kwargs, expected_seq_count
+
+
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_input_processor_kwargs(use_processor_mock, init_num_crops,
+                                inference_num_crops):
+    """Ensure input processors can use processor kwargs."""
+    dummy_registry = InputRegistry()
+
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
+
+    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=inference_kwargs))
+    assert num_crops_val == expected_seq_count


 @pytest.mark.parametrize(
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                             mm_processor_kwargs):
     """Ensure that input processors filter out invalid mm_processor_kwargs"""
     dummy_registry = InputRegistry()

+    # Should filter out the init time kwargs
     ctx = build_model_context(DUMMY_MODEL_ID,
                               mm_processor_kwargs=mm_processor_kwargs)
     processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    # Should filter out the inference time kwargs
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=mm_processor_kwargs))
     assert num_crops_val == DEFAULT_NUM_CROPS
@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
     assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
+                                       inference_num_crops):
     """Ensure custom mappers can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
-    }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
+                              mm_processor_kwargs=init_kwargs,
                               limit_mm_per_prompt={"image": 1})

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
+                                          inference_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                 mm_processor_kwargs):
     """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
+    # Should filter out the init time kwargs
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echos
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echos
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    # Should filter out the inference time kwargs
+    mapped_inputs = mm_registry.map_input(
+        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1

View File

@@ -2,6 +2,7 @@ from typing import List
 import pytest

+from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [

@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
 def test_parse_single_batch_string_slice(inputs_slice: slice):
     assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
         == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+
+
+# yapf: disable
+@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
+    (None, [{}, {}]),
+    ({}, [{}, {}]),
+    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
+    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
+])
+# yapf: enable
+def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
+    """Test mm_processor_kwargs init for zipping enc/dec prompts."""
+    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
+    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
+    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
+                                         mm_processor_kwargs)
+    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
+    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
+                                            expected_mm_kwargs,
+                                            zipped_prompts):
+        assert isinstance(zipped, dict)
+        assert len(zipped.keys()) == 3
+        assert zipped['encoder_prompt'] == enc
+        assert zipped['decoder_prompt'] == dec
+        assert zipped['mm_processor_kwargs'] == exp_kwargs

View File

@@ -7,7 +7,7 @@ from typing import AsyncIterator, Tuple
 import pytest

 from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
-                        get_open_port, merge_async_iterators)
+                        get_open_port, merge_async_iterators, supports_kw)

 from .utils import error_on_warning

@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
     with pytest.raises(ValueError):
         parser_with_config.parse_args(
             ['serve', '--config', './data/test_config.yaml'])
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
+    [
+        # Tests for positional argument support
+        (lambda foo: None, "foo", True, True, False),
+        (lambda foo: None, "foo", False, True, True),
+        # Tests for positional or keyword / keyword only
+        (lambda foo=100: None, "foo", True, True, False),
+        (lambda *, foo: None, "foo", False, True, True),
+        # Tests to make sure the names of variadic params are NOT supported
+        (lambda *args: None, "args", False, True, False),
+        (lambda **kwargs: None, "kwargs", False, True, False),
+        # Tests for if we allow var kwargs to add support
+        (lambda foo: None, "something_else", False, True, False),
+        (lambda foo, **kwargs: None, "something_else", False, True, True),
+        (lambda foo, **kwargs: None, "kwargs", True, True, False),
+        (lambda foo, **kwargs: None, "foo", True, True, False),
+    ])
+# yapf: disable
+def test_supports_kw(callable,kw_name,requires_kw_only,
+                     allow_var_kwargs,is_supported):
+    assert supports_kw(
+        callable=callable,
+        kw_name=kw_name,
+        requires_kw_only=requires_kw_only,
+        allow_var_kwargs=allow_var_kwargs
+    ) == is_supported

View File

@@ -1309,6 +1309,7 @@ class Scheduler:
                 # `multi_modal_data` will be None.
                 multi_modal_data=seq_group.multi_modal_data
                 if scheduler_outputs.num_prefill_groups > 0 else None,
+                mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
         else:

View File

@@ -811,6 +811,13 @@ class LLMEngine:
         )
         processed_inputs = self.input_processor(preprocessed_inputs)

+        # This is a bit of a hack - copy the mm_processor_kwargs that were
+        # used in the input processor to the processed output, since these
+        # kwargs are presumed to be immutable and the values should be aligned
+        # between the input processor (here) and the input mapper.
+        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
+            "mm_processor_kwargs")
+
         self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,

View File

@@ -472,6 +472,7 @@
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.

@@ -501,6 +502,8 @@
             continue_final_message: If True, continues the final message in
                 the conversation instead of starting a new one. Cannot be
                 `True` if `add_generation_prompt` is also `True`.
+            mm_processor_kwargs: Multimodal processor kwarg overrides for this
+                chat request. Only used for offline requests.

         Returns:
             A list of ``RequestOutput`` objects containing the generated

@@ -522,6 +525,9 @@
             tokenizer = self.get_tokenizer()
             model_config = self.llm_engine.get_model_config()

+            # NOTE: _parse_chat_message_content_parts() currently doesn't
+            # handle mm_processor_kwargs, since there is no implementation in
+            # the chat message parsing for it.
             conversation, mm_data = parse_chat_messages(
                 msgs, model_config, tokenizer)

@@ -554,6 +560,9 @@
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data

+            if mm_processor_kwargs is not None:
+                prompt["mm_processor_kwargs"] = mm_processor_kwargs
+
             prompts.append(prompt)

         return self.generate(
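A hedged sketch of the new chat-level override in the offline API; the model name, image URL, and kwarg value are placeholders rather than values taken from the diff:

```python
from vllm import LLM

llm = LLM(model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)

outputs = llm.chat(
    [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/duck.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
    # Attached to every prompt built from this conversation; takes priority
    # over any mm_processor_kwargs passed when the LLM was constructed.
    mm_processor_kwargs={"num_crops": 4},
)
```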

View File

@@ -1,5 +1,5 @@
-from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
+                    Optional, Tuple, Union)

 from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 class TokensPrompt(TypedDict):
     """Schema for a tokenized prompt."""

@@ -32,6 +40,14 @@ class TokensPrompt(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """

@@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
     according to any of the :class:`SingletonPrompt` schemas,
     and are not required to have the same schema.

-    Only the encoder prompt may have multi-modal data.
+    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
+    should be at the top-level, and should not be set in the encoder/decoder
+    prompts, since they are agnostic to the encoder/decoder.

     Note that an :class:`ExplicitEncoderDecoderPrompt` may not
     be used as an input to a decoder-only model,

@@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):

     decoder_prompt: Optional[_T2_co]

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+

 PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
 """

@@ -121,6 +141,14 @@ class LLMInputs(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 class EncoderDecoderLLMInputs(LLMInputs):
     """

@@ -152,22 +180,43 @@ _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
 def build_explicit_enc_dec_prompt(
     encoder_prompt: _T1,
     decoder_prompt: Optional[_T2],
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
-    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
-                                        decoder_prompt=decoder_prompt)
+    if mm_processor_kwargs is None:
+        mm_processor_kwargs = {}
+    return ExplicitEncoderDecoderPrompt(
+        encoder_prompt=encoder_prompt,
+        decoder_prompt=decoder_prompt,
+        mm_processor_kwargs=mm_processor_kwargs)


 def zip_enc_dec_prompts(
     enc_prompts: Iterable[_T1],
     dec_prompts: Iterable[Optional[_T2]],
+    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
+                                        Dict[str, Any]]] = None,
 ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
     """
     Zip encoder and decoder prompts together into a list of
-    :class:`ExplicitEncoderDecoderPrompt` instances.
+    :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
+    may also be provided; if a dict is passed, the same dictionary will be
+    used for every encoder/decoder prompt. If an iterable is provided, it will
+    be zipped with the encoder/decoder prompts.
     """
+    if mm_processor_kwargs is None:
+        mm_processor_kwargs = {}
+    if isinstance(mm_processor_kwargs, Dict):
+        return [
+            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
+                                          mm_processor_kwargs)
+            for (encoder_prompt,
+                 decoder_prompt) in zip(enc_prompts, dec_prompts)
+        ]
     return [
-        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
-        for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
+        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
+                                      mm_proc_kwargs)
+        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
+             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
     ]
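A short usage sketch of the updated helpers; the kwarg values are placeholders and the encoder/decoder strings are arbitrary:

```python
from vllm.inputs import zip_enc_dec_prompts

enc_prompts = ["Describe the first image", "Describe the second image"]
dec_prompts = ["", ""]

# A single dict attaches the same overrides to every zipped prompt.
shared = zip_enc_dec_prompts(enc_prompts, dec_prompts, {"num_crops": 4})
assert shared[0]["mm_processor_kwargs"] == {"num_crops": 4}

# An iterable is zipped positionally, one overrides dict per prompt pair.
per_prompt = zip_enc_dec_prompts(enc_prompts, dec_prompts,
                                 [{"num_crops": 4}, {"num_crops": 16}])
assert per_prompt[1]["mm_processor_kwargs"] == {"num_crops": 16}
```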

View File

@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from typing_extensions import assert_never

@@ -20,9 +20,11 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 PromptComponents = Tuple[Optional[str], List[int],
-                         Optional["MultiModalDataDict"]]
+                         Optional["MultiModalDataDict"], Optional[Dict[str,
+                                                                       Any]]]
 DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
-                                Optional["MultiModalDataDict"]]
+                                Optional["MultiModalDataDict"],
+                                Optional[Dict[str, Any]]]


 class InputPreprocessor:
@@ -227,6 +229,7 @@
         * prompt
         * prompt_token_ids
         * multi_modal_data
+        * mm_processor_kwargs (request-level input processor/mapper overrides)
         '''

         parsed = parse_singleton_prompt(prompt)

@@ -239,10 +242,12 @@
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            mm_processor_kwargs = None
         elif parsed["type"] == "tokens":
             prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         elif parsed["type"] == "text":
             prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = self._tokenize_prompt(

@@ -251,10 +256,12 @@
                 lora_request=lora_request,
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         else:
             assert_never(parsed)

-        return prompt_text, prompt_token_ids, multi_modal_data
+        return (prompt_text, prompt_token_ids, multi_modal_data,
+                mm_processor_kwargs)

     async def _extract_prompt_components_async(
         self,
@@ -273,10 +280,12 @@
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            mm_processor_kwargs = None
         elif parsed["type"] == "tokens":
             prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         elif parsed["type"] == "text":
             prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(

@@ -285,18 +294,21 @@
                 lora_request=lora_request,
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         else:
             assert_never(parsed)

-        return prompt_text, prompt_token_ids, multi_modal_data
+        return (prompt_text, prompt_token_ids, multi_modal_data,
+                mm_processor_kwargs)
     def _build_enc_dec_llm_inputs(
         self,
         encoder_comps: PromptComponents,
         decoder_comps: DecoderPromptComponents,
+        mm_processor_kwargs: Dict[str, Any],
     ) -> EncoderDecoderLLMInputs:
-        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
-        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+        encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
+        decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps

         if decoder_mm_data is not None:
             raise ValueError(

@@ -314,6 +326,7 @@
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
             multi_modal_data=decoder_mm_data,
+            mm_processor_kwargs=mm_processor_kwargs,
             encoder_prompt_token_ids=encoder_prompt_ids,
             encoder_prompt=encoder_prompt,
             encoder_multi_modal_data=encoder_mm_data,
@@ -367,21 +380,30 @@
             )

             if (decoder_input := prompt["decoder_prompt"]) is None:
-                decoder_comps = None, None, None
+                decoder_comps = None, None, None, None
             else:
                 decoder_comps = self._extract_prompt_components(
                     decoder_input,
                     request_id=request_id,
                 )
+            # Handle this carefully in case it was directly initialized by user
+            mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
         else:
             encoder_comps = self._extract_prompt_components(
                 prompt,
                 request_id=request_id,
             )
+            # If there are no decoder components, we assume the
+            # mm_processor_kwargs are in the encoder prompt
+            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
+                -1] is not None else {}
+            decoder_comps = None, None, None, None

-            decoder_comps = None, None, None
-
-        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+        return self._build_enc_dec_llm_inputs(
+            encoder_comps,
+            decoder_comps,
+            mm_processor_kwargs,
+        )

     async def _process_encoder_decoder_prompt_async(
         self,
@@ -400,7 +422,7 @@
         if (decoder_input := prompt["decoder_prompt"]) is None:
             encoder_comps = await encoder_task
-            decoder_comps = None, None, None
+            decoder_comps = None, None, None, None
         else:
             decoder_task = self._extract_prompt_components_async(
                 decoder_input,
@@ -409,29 +431,39 @@
             encoder_comps, decoder_comps = await asyncio.gather(
                 encoder_task, decoder_task)
+            mm_processor_kwargs = prompt["mm_processor_kwargs"]
         else:
             encoder_comps = await self._extract_prompt_components_async(
                 prompt,
                 request_id=request_id,
             )
+            # If there are no decoder components, we assume the
+            # mm_processor_kwargs are in the encoder prompt
+            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
+                -1] is not None else {}
+            decoder_comps = None, None, None, None

-            decoder_comps = None, None, None
-
-        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+        return self._build_enc_dec_llm_inputs(
+            encoder_comps,
+            decoder_comps,
+            mm_processor_kwargs,
+        )

     def _build_decoder_only_llm_inputs(
         self,
         prompt_comps: PromptComponents,
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> LLMInputs:
-        prompt, prompt_token_ids, multi_modal_data = prompt_comps
+        (prompt, prompt_token_ids, multi_modal_data,
+         mm_processor_kwargs) = prompt_comps

         prompt_token_ids = self._apply_prompt_adapter(
             prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

         return LLMInputs(prompt_token_ids=prompt_token_ids,
                          prompt=prompt,
-                         multi_modal_data=multi_modal_data)
+                         multi_modal_data=multi_modal_data,
+                         mm_processor_kwargs=mm_processor_kwargs)

     def _process_decoder_only_prompt(
         self,

View File

@@ -9,7 +9,8 @@ from transformers import PretrainedConfig
 from typing_extensions import TypeVar

 from vllm.logger import init_logger
-from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once
+from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
+                        resolve_mm_processor_kwargs)

 from .data import LLMInputs

@@ -293,8 +294,14 @@
         model_cls, _ = get_model_architecture(model_config)
         processor = self._get_model_input_processor(model_cls)

-        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
-            processor, overrides=model_config.mm_processor_kwargs)
+        # Handle multimodal processor kwargs with priority:
+        #     Inference kwargs -> Init kwargs -> {}
+        # If it's empty, it'll fall back to the default kwarg values
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            model_config.mm_processor_kwargs,
+            inputs.get("mm_processor_kwargs"),
+            processor,
+        )

         return processor(InputContext(model_config), inputs,
                          **mm_processor_kwargs)
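For model developers, the practical consequence is that a registered input processor only receives overrides that match its keyword-only parameters; everything else is filtered out before the call. A minimal sketch under that assumption (the function and its `num_crops` parameter are hypothetical; imports follow the ones used in this PR's tests, and registration is typically done by decorating the model class with `INPUT_REGISTRY.register_input_processor`):

```python
from vllm.inputs import InputContext, LLMInputs

def input_processor_for_my_model(ctx: InputContext,
                                 llm_inputs: LLMInputs,
                                 *,
                                 num_crops: int = 16) -> LLMInputs:
    # Because num_crops is keyword-only, it can be overridden either at
    # LLM construction time or per request; unknown kwargs are dropped by
    # resolve_mm_processor_kwargs before this function is invoked.
    # ... expand image placeholder tokens based on num_crops ...
    return llm_inputs
```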

View File

@@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin):
     def get_data_key(self) -> str:
         return "audio"

-    def _default_input_mapper(self, ctx: InputContext,
-                              data: object) -> MultiModalInputs:
+    def _default_input_mapper(self, ctx: InputContext, data: object,
+                              **mm_processor_kwargs) -> MultiModalInputs:
         raise NotImplementedError("There is no default audio input mapper")

     def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:

View File

@@ -1,7 +1,7 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type,
                     TypedDict, TypeVar, Union, cast, final)

 import numpy as np

@@ -15,7 +15,7 @@ from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
-                        json_map_leaves)
+                        json_map_leaves, resolve_mm_processor_kwargs)

 logger = init_logger(__name__)
@@ -200,6 +200,7 @@
         self,
         ctx: InputContext,
         data: MultiModalData[object],
+        **mm_processor_kwargs,
     ) -> MultiModalInputs:
         """
         Return a dictionary to be passed as keyword arguments to

@@ -243,7 +244,8 @@
         return wrapper

     def map_input(self, model_config: ModelConfig,
-                  data: MultiModalData[object]) -> MultiModalInputs:
+                  data: MultiModalData[object],
+                  mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
         """
         Transform the data into a dictionary of model inputs using the
         input mapper registered for that model.

@@ -263,19 +265,26 @@
         model_cls, _ = get_model_architecture(model_config)

         mapper = self._input_mappers.get(model_cls)
-        # Only get processor kwargs at mapping time if we are not using the
-        # input mapper; no overrides are used on the default here because they
-        # should be passed to the huggingface resource at initialization time.
-        if mapper is not None and mapper != self._default_input_mapper:
-            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
-                mapper, overrides=model_config.mm_processor_kwargs)
-        else:
-            mm_processor_kwargs = {}

         if mapper is None:
             raise KeyError(f"No input mapper in {self} is registered for "
                            f"model class {model_cls.__name__}.")

+        # In the case of the default mapper, we have to get the resource
+        # processor through its HuggingFace autoclass; since this goes
+        # through **kwargs, we can't inspect it the same way, so we allow
+        # dropping mm_processor_kwargs based on signature inspection
+        # if we're using the default mapper.
+        #
+        # This should be safe in general due to the sanitation, since the
+        # transformers resource should filter unused kwargs anyway.
+        uses_default_mapper = mapper == self._default_input_mapper
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            model_config.mm_processor_kwargs,
+            mm_processor_kwargs,
+            callable=mapper,
+            allow_var_kwargs=uses_default_mapper,
+        )
         return mapper(InputContext(model_config), data, **mm_processor_kwargs)

     @abstractmethod
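Similarly, a custom input mapper opts into these overrides by declaring them as keyword-only parameters. A sketch under the same assumptions (the tensor shape mirrors the lambda used in the tests above, and the `num_crops` default is a placeholder):

```python
import torch
from vllm.inputs import InputContext

def custom_image_mapper(ctx: InputContext, data, *, num_crops: int = 16):
    # The second dim of pixel_values reflects the (possibly overridden)
    # number of crops, matching the custom_mapper lambda in the tests.
    return {"pixel_values": torch.zeros(1, num_crops + 1, 3, 336, 336)}

# Usually attached to a model class as a decorator, e.g.:
# @MULTIMODAL_REGISTRY.register_image_input_mapper(custom_image_mapper)
# class MyMultiModalModel(nn.Module): ...
```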

View File

@@ -1,4 +1,5 @@
 from functools import lru_cache
+from typing import Any, Dict, Optional

 import torch
 from PIL import Image

@@ -23,11 +24,13 @@
     def get_data_key(self) -> str:
         return "image"

-    def _get_hf_image_processor(self, model_config: ModelConfig):
-        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
-                               else model_config.mm_processor_kwargs)
-        # We don't explicitly check kwarg overrides to the HF class
-        # since the automodel just takes kwargs, so we can't inspect it
+    def _get_hf_image_processor(
+        self,
+        model_config: ModelConfig,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        if mm_processor_kwargs is None:
+            mm_processor_kwargs = {}
         return cached_get_image_processor(
             model_config.model,
             trust_remote_code=model_config.trust_remote_code,

@@ -37,6 +40,7 @@
         self,
         ctx: InputContext,
         data: MultiModalData[object],
+        **mm_processor_kwargs,
     ) -> MultiModalInputs:
         model_config = ctx.model_config

@@ -46,12 +50,20 @@
         # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
-            image_processor = self._get_hf_image_processor(model_config)
+            image_processor = self._get_hf_image_processor(
+                model_config,
+                mm_processor_kwargs,
+            )
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available "
                                    "to process the image object")
             try:
+                # NOTE: It may make sense to forward the mm_processor_kwargs
+                # here too. For now, to keep it simple, we only allow them to
+                # be used for the initialization call, just in case the
+                # signatures of the preprocessor initializer don't match
+                # preprocess()
                 batch_data = image_processor \
                     .preprocess(data, return_tensors="pt") \
                     .data

View File

@@ -1,6 +1,6 @@
 import functools
 from collections import UserDict
-from typing import Dict, Mapping, Optional, Sequence
+from typing import Any, Dict, Mapping, Optional, Sequence

 from vllm.config import ModelConfig
 from vllm.logger import init_logger

@@ -96,8 +96,12 @@
         """
         return self.register_input_mapper("image", mapper)

-    def map_input(self, model_config: ModelConfig,
-                  data: MultiModalDataDict) -> MultiModalInputs:
+    def map_input(
+        self,
+        model_config: ModelConfig,
+        data: MultiModalDataDict,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> MultiModalInputs:
         """
         Apply an input mapper to the data passed to the model.

@@ -123,7 +127,8 @@
                     f"`--limit-mm-per-prompt`, but found {num_items} items "
                     "in the same prompt.")

-            input_dict = plugin.map_input(model_config, data_value)
+            input_dict = plugin.map_input(model_config, data_value,
+                                          mm_processor_kwargs)

             for input_key, input_tensor in input_dict.items():
                 if input_key in merged_dict:
                     raise ValueError(f"The input mappers (keys={set(data)}) "

View File

@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import List, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np

@@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin):
     def get_data_key(self) -> str:
         return "video"

-    def _get_hf_video_processor(self, model_config: ModelConfig):
-        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
-                               else model_config.mm_processor_kwargs)
-        # We don't explicitly check kwarg overrides to the HF class
-        # since the automodel just takes kwargs, so we can't inspect it
+    def _get_hf_video_processor(
+        self,
+        model_config: ModelConfig,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        if mm_processor_kwargs is None:
+            mm_processor_kwargs = {}
         return cached_get_video_processor(
             model_config.model,
             trust_remote_code=model_config.trust_remote_code,

@@ -50,16 +52,24 @@ class VideoPlugin(ImagePlugin):
         self,
         ctx: InputContext,
         data: MultiModalData[object],
+        **mm_processor_kwargs,
     ) -> MultiModalInputs:
         model_config = ctx.model_config

         # single video input as np.ndarray
         if isinstance(data, np.ndarray):
-            video_processor = self._get_hf_video_processor(model_config)
+            video_processor = self._get_hf_video_processor(
+                model_config,
+                mm_processor_kwargs,
+            )
             if video_processor is None:
                 raise RuntimeError("No HuggingFace processor is available "
                                    "to process the image object")
             try:
+                # NOTE: Similar to image; it may be a good idea to filter and
+                # pass mm_processor_kwargs here too, but for now we don't to
+                # avoid extra complexity if the initializer and preprocess
+                # signatures of the processor don't align
                 batch_data = video_processor(data, return_tensors="pt").data
             except Exception:
                 logger.error("Failed to process image (%s)", data)

View File

@@ -481,6 +481,10 @@ class Sequence:
                     EncoderDecoderLLMInputs,
                     inputs).get("encoder_multi_modal_data")) or {}

+    @property
+    def mm_processor_kwargs(self) -> Dict[str, Any]:
+        return self.inputs.get("mm_processor_kwargs") or {}
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

@@ -710,6 +714,14 @@ class SequenceGroup:
         # We use the multi-modal data of an arbitrary sequence.
         return self.seqs[0].multi_modal_data

+    @property
+    def mm_processor_kwargs(self) -> Dict[str, Any]:
+        # As with multi-modal data, all sequences in the group should have the
+        # same processor kwargs (i.e., mm_processor_kwargs are optionally
+        # provided per request; note that they are independent of whether the
+        # model is decoder-only or an encoder-decoder).
+        return self.seqs[0].mm_processor_kwargs
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

@@ -949,6 +961,7 @@
             used in prefix caching.
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
+        mm_processor_kwargs: Multimodal input processor / mapper overrides.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder

@@ -975,6 +988,7 @@
     # "MultiModalDataDict" types. We have to use Any due to msgspec
     # doesn't allow to have union of 2 different dicts.
     multi_modal_data: Optional[Any] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[List[int]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None

View File

@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     return await task(*args, **kwargs)


-def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
+def supports_kw(
+    callable: Callable[..., object],
+    kw_name: str,
+    requires_kw_only: bool = False,
+    allow_var_kwargs: bool = True,
+) -> bool:
+    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
+    disallows kwarg names that can also be positional arguments.
+    """
     params = inspect.signature(callable).parameters
-    if kw_name in params:
-        return True
+    if not params:
+        return False

-    return any(param.kind == inspect.Parameter.VAR_KEYWORD
-               for param in params.values())
+    param_val = params.get(kw_name)
+
+    # Types where it may be valid, i.e., explicitly defined & nonvariadic
+    passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
+                             inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                             inspect.Parameter.KEYWORD_ONLY))
+
+    if param_val:
+        is_sig_param = param_val.kind in passable_kw_types
+        # We want kwargs only, but this is passable as a positional arg
+        if (requires_kw_only and is_sig_param
+                and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
+            return False
+        if ((requires_kw_only
+             and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
+                or (not requires_kw_only and is_sig_param)):
+            return True
+
+    # If we're okay with var-kwargs, it's supported as long as
+    # the kw_name isn't something like *args, **kwargs
+    if allow_var_kwargs:
+        # Get the last param; type is ignored here because params is a proxy
+        # mapping, but it wraps an ordered dict, and they appear in order.
+        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
+        last_param = params[next(reversed(params))]  # type: ignore
+        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
+                and last_param.name != kw_name)
+    return False
+
+
+def resolve_mm_processor_kwargs(
+    init_kwargs: Optional[Dict[str, Any]],
+    inference_kwargs: Optional[Dict[str, Any]],
+    callable: Callable[..., object],
+    allow_var_kwargs: bool = False,
+) -> Dict[str, Any]:
+    """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
+    those that are not explicit keywords of the given callable (if one is
+    given; otherwise no filtering is done), then merges the kwarg dicts,
+    giving priority to inference_kwargs if there are any collisions.
+
+    In the case that no kwarg overrides are provided, returns an empty
+    dict so that it can still be kwarg expanded into the callable later on.
+
+    If allow_var_kwargs=True, allows for things that can be expanded into
+    kwargs as long as they aren't a naming collision with var_kwargs or
+    potential positional arguments.
+    """
+    # Filter inference time multimodal processor kwargs provided
+    runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable,
+        overrides=inference_kwargs,
+        allow_var_kwargs=allow_var_kwargs)
+
+    # Filter init time multimodal processor kwargs provided
+    init_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
+
+    # Merge the final processor kwargs, prioritizing inference
+    # time values over the initialization time values.
+    mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
+    return mm_processor_kwargs


 def get_allowed_kwarg_only_overrides(
     callable: Callable[..., object],
     overrides: Optional[Dict[str, Any]],
+    allow_var_kwargs: bool = False,
 ) -> Dict[str, Any]:
     """
     Given a callable which has one or more keyword only params and a dict
@@ -1300,7 +1369,9 @@

     Args:
         callable: Callable which takes 0 or more keyword only arguments.
+                  If None is provided, all override names are allowed.
         overrides: Potential overrides to be used when invoking the callable.
+        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

     Returns:
         Dictionary containing the kwargs to be leveraged which may be used
@@ -1310,17 +1381,15 @@
     if not overrides:
         return {}

-    allowed_override_names = [
-        name for name, param in inspect.signature(callable).parameters.items()
-        if param.kind == inspect.Parameter.KEYWORD_ONLY
-    ]
-
-    # Drop any mm_processor_kwargs provided by the user that are
-    # not kwarg names accepted by the provided input processor.
+    # Drop any mm_processor_kwargs provided by the user that are not
+    # kwargs, unless they can fit the var_kwargs param
     filtered_overrides = {
         kwarg_name: val
        for kwarg_name, val in overrides.items()
-        if kwarg_name in allowed_override_names
+        if supports_kw(callable,
+                       kwarg_name,
+                       requires_kw_only=True,
+                       allow_var_kwargs=allow_var_kwargs)
     }

     # If anything is dropped, log a warning
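A small sketch of how the two new utilities compose, assuming a hypothetical keyword-only processor function:

```python
from vllm.utils import resolve_mm_processor_kwargs, supports_kw

def fake_processor(images, *, num_crops: int = 16, do_resize: bool = True):
    ...

# Keyword-only params are supported; plain positional params are rejected
# when requires_kw_only=True, which is how the override filtering works.
assert supports_kw(fake_processor, "num_crops", requires_kw_only=True)
assert not supports_kw(fake_processor, "images", requires_kw_only=True)

# Unknown names are dropped, and inference-time values win over init-time.
resolved = resolve_mm_processor_kwargs(
    init_kwargs={"num_crops": 16, "bogus": 1},
    inference_kwargs={"num_crops": 4},
    callable=fake_processor,
)
assert resolved == {"num_crops": 4}
```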

View File

@@ -148,8 +148,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
         )

     def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
-                                   computed_len: int):
-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                                   computed_len: int,
+                                   mm_processor_kwargs: Dict[str, Any]):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)

         # special processing for mrope position deltas.
         mrope_positions = None

@@ -210,7 +211,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
             mrope_positions = None
             if (mm_data := seq_group_metadata.multi_modal_data):
                 mm_kwargs, mrope_positions = self._compute_multi_modal_input(
-                    seq_data, mm_data, computed_len)
+                    seq_data, mm_data, computed_len,
+                    seq_group_metadata.mm_processor_kwargs)
                 multi_modal_inputs_list.append(mm_kwargs)

             # Token position ids

View File

@@ -640,7 +640,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         if not mm_data:
             return

-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+        mm_kwargs = self.multi_modal_input_mapper(
+            mm_data,
+            mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
         inter_data.multi_modal_inputs = mm_kwargs

         # special processing for mrope position deltas.

View File

@@ -153,7 +153,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data:
                 # Process multi-modal data
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                mm_kwargs = self.multi_modal_input_mapper(
+                    mm_data,
+                    mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
+                )
                 multi_modal_inputs_list.append(mm_kwargs)

         max_seq_len = max(seq_lens)

View File

@@ -172,7 +172,11 @@ class OpenVINOModelRunner:
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                mm_kwargs = self.multi_modal_input_mapper(
+                    mm_data,
+                    mm_processor_kwargs=seq_group_metadata.
+                    mm_processor_kwargs,
+                )
                 multi_modal_inputs_list.append(mm_kwargs)

             block_table = seq_group_metadata.block_tables[seq_id]
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]