[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent 8c746226c9
commit a3691b6b5e
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"num_crops": 16},
    )
    stop_token_ids = None
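The hunk above notes that mm_processor_kwargs can be set either when the engine is constructed or per request. A minimal offline sketch of both paths, assuming the same Phi-3-vision setup as the example file (the image path and prompt string here are illustrative, not taken from this diff):

from PIL import Image
from vllm import LLM, SamplingParams

# Init-time override: applies to every request served by this engine.
llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    max_num_seqs=2,
    mm_processor_kwargs={"num_crops": 16},
)

# Inference-time override: attached to one prompt, takes priority over the
# init-time value for that request only.
image = Image.open("example.jpg")  # illustrative input
outputs = llm.generate(
    {
        "prompt": "<|user|>\n<|image_1|>\nWhat is shown here?<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": image},
        "mm_processor_kwargs": {"num_crops": 4},
    },
    SamplingParams(max_tokens=64),
)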
@@ -74,11 +74,11 @@ def mm_model_cls():
    # lambda whose signature matches max token calcs extra & mapper + extra kwargs
    get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
    custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
        "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
        "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
    }


### Test for default processor logic & mm_processor_kwargs wrapping
### Tests for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
    assert proc_inputs is proc_outputs


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()
def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
    """Get the init / inference kwargs and expected num_crops for this test."""
    # If we have a value for num_crops, pass the override value and make
    # sure we get that value as a return-value from out mock processor,
    # otherwise fall back to the default value
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    init_kwargs = None if init_num_crops is None else {
        "num_crops": init_num_crops
    }
    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    inference_kwargs = None if inference_num_crops is None else {
        "num_crops": inference_num_crops
    }
    if inference_num_crops is not None:
        expected_seq_count = inference_num_crops
    elif init_num_crops is not None:
        expected_seq_count = init_num_crops
    else:
        expected_seq_count = DEFAULT_NUM_CROPS
    return init_kwargs, inference_kwargs, expected_seq_count

    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    assert num_crops_val == expected_num_crops

@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
    (None, None),
    (NUM_CROPS_OVERRIDE, None),
    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_input_processor_kwargs(use_processor_mock, init_num_crops,
                                inference_num_crops):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()

    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
        init_num_crops, inference_num_crops)

    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    num_crops_val = processor(
        LLMInputs(prompt_token_ids=[],
                  prompt="",
                  mm_processor_kwargs=inference_kwargs))
    assert num_crops_val == expected_seq_count


@pytest.mark.parametrize(
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                            mm_processor_kwargs):
    """Ensure that input processors filter out invalid mm_processor_kwargs"""
    dummy_registry = InputRegistry()
    # Should filter out the init time kwargs
    ctx = build_model_context(DUMMY_MODEL_ID,
                              mm_processor_kwargs=mm_processor_kwargs)

    processor = dummy_registry.create_input_processor(ctx.model_config)
    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    # Should filter out the inference time kwargs
    num_crops_val = processor(
        LLMInputs(prompt_token_ids=[],
                  prompt="",
                  mm_processor_kwargs=mm_processor_kwargs))
    assert num_crops_val == DEFAULT_NUM_CROPS


@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
    (None, None),
    (NUM_CROPS_OVERRIDE, None),
    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
])
def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
                                       inference_num_crops):
    """Ensure custom mappers can use processor kwargs."""
    mm_processor_kwargs = None if num_crops is None else {
        "num_crops": num_crops
    }
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
        init_num_crops, inference_num_crops)

    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
                              mm_processor_kwargs=init_kwargs,
                              limit_mm_per_prompt={"image": 1})

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
            mm_registry._get_plugin("image"),
            "_default_input_mapper",
            {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
                                          inference_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                mm_processor_kwargs):
    """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
    # Should filter out the init time kwargs
    ctx = build_model_context(MULTIMODAL_MODEL_ID,
                              trust_remote_code=True,
                              mm_processor_kwargs=mm_processor_kwargs,
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,

    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
            mm_registry._get_plugin("image"),
            "_default_input_mapper",
            {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
        mm_model_cls())
    # Should filter out the inference time kwargs
    mapped_inputs = mm_registry.map_input(
        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1
@@ -2,6 +2,7 @@ from typing import List

import pytest

from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_and_batch_prompt

STRING_INPUTS = [
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
def test_parse_single_batch_string_slice(inputs_slice: slice):
    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])


# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
    (None, [{}, {}]),
    ({}, [{}, {}]),
    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Test mm_processor_kwargs init for zipping enc/dec prompts."""
    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
                                         mm_processor_kwargs)
    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
                                            expected_mm_kwargs,
                                            zipped_prompts):
        assert isinstance(zipped, dict)
        assert len(zipped.keys()) == 3
        assert zipped['encoder_prompt'] == enc
        assert zipped['decoder_prompt'] == dec
        assert zipped['mm_processor_kwargs'] == exp_kwargs
@@ -7,7 +7,7 @@ from typing import AsyncIterator, Tuple

import pytest

from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
                        get_open_port, merge_async_iterators)
                        get_open_port, merge_async_iterators, supports_kw)

from .utils import error_on_warning

@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ['serve', '--config', './data/test_config.yaml'])


# yapf: enable
@pytest.mark.parametrize(
    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for if we allow var kwargs to add support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ])
# yapf: disable
def test_supports_kw(callable,kw_name,requires_kw_only,
                     allow_var_kwargs,is_supported):
    assert supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs
    ) == is_supported
@@ -1309,6 +1309,7 @@ class Scheduler:
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.num_prefill_groups > 0 else None,
                mm_processor_kwargs=seq_group.mm_processor_kwargs,
                prompt_adapter_request=seq_group.prompt_adapter_request,
            )
        else:
@@ -811,6 +811,13 @@ class LLMEngine:
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        # This is a bit of a hack - copy the mm_processor_kwargs that were
        # used in the input processor to the processed output, since these
        # kwargs are presumed to be immutable and the values should be aligned
        # between the input processor (here) and the input mapper.
        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
            "mm_processor_kwargs")

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
@@ -472,6 +472,7 @@ class LLM:
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.
@@ -501,6 +502,8 @@ class LLM:
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
@@ -522,6 +525,9 @@ class LLM:
            tokenizer = self.get_tokenizer()
            model_config = self.llm_engine.get_model_config()

            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

@@ -554,6 +560,9 @@ class LLM:
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
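The new mm_processor_kwargs argument on LLM.chat applies a single override dict to the whole chat request (offline only, per the docstring above). A hedged sketch, with the message content assumed for illustration rather than taken from this diff:

outputs = llm.chat(
    [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/cat.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
    mm_processor_kwargs={"num_crops": 16},
)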
@@ -1,5 +1,5 @@
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
                    Union)
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
                    Optional, Tuple, Union)

from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""
@@ -32,6 +40,14 @@ class TokensPrompt(TypedDict):
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """


SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
@@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data.
    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
@@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):

    decoder_prompt: Optional[_T2_co]

    mm_processor_kwargs: NotRequired[Dict[str, Any]]


PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
@@ -121,6 +141,14 @@ class LLMInputs(TypedDict):
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """


class EncoderDecoderLLMInputs(LLMInputs):
    """
@@ -152,22 +180,43 @@ _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
                                        decoder_prompt=decoder_prompt)
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs)


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
                                        Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.
    :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
    may also be provided; if a dict is passed, the same dictionary will be
    used for every encoder/decoder prompt. If an iterable is provided, it will
    be zipped with the encoder/decoder prompts.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    if isinstance(mm_processor_kwargs, Dict):
        return [
            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                          mm_processor_kwargs)
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
        for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]
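Per the updated docstring, zip_enc_dec_prompts accepts either one kwargs dict shared by every pair or an iterable zipped pairwise with the prompts; a small sketch mirroring the test added earlier in this diff:

from vllm.inputs import zip_enc_dec_prompts

enc = ["An encoder prompt", "Another encoder prompt"]
dec = ["A decoder prompt", "Another decoder prompt"]

# One dict shared by every encoder/decoder pair.
shared = zip_enc_dec_prompts(enc, dec, {"num_crops": 4})
assert shared[0]["mm_processor_kwargs"] == {"num_crops": 4}

# One dict per pair, zipped positionally with the prompts.
per_pair = zip_enc_dec_prompts(enc, dec, [{"num_crops": 4}, {"num_crops": 16}])
assert per_pair[1]["mm_processor_kwargs"] == {"num_crops": 16}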
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from typing_extensions import assert_never

@@ -20,9 +20,11 @@ if TYPE_CHECKING:
logger = init_logger(__name__)

PromptComponents = Tuple[Optional[str], List[int],
                         Optional["MultiModalDataDict"]]
                         Optional["MultiModalDataDict"], Optional[Dict[str,
                                                                       Any]]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional["MultiModalDataDict"]]
                                Optional["MultiModalDataDict"],
                                Optional[Dict[str, Any]]]


class InputPreprocessor:
@@ -227,6 +229,7 @@ class InputPreprocessor:
        * prompt
        * prompt_token_ids
        * multi_modal_data
        * mm_processor_kwargs (request-level input processor/mapper overrides)
        '''

        parsed = parse_singleton_prompt(prompt)
@@ -239,10 +242,12 @@ class InputPreprocessor:
                lora_request=lora_request,
            )
            multi_modal_data = None
            mm_processor_kwargs = None
        elif parsed["type"] == "tokens":
            prompt_text = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
        elif parsed["type"] == "text":
            prompt_text = parsed["content"]["prompt"]
            prompt_token_ids = self._tokenize_prompt(
@@ -251,10 +256,12 @@ class InputPreprocessor:
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
        else:
            assert_never(parsed)

        return prompt_text, prompt_token_ids, multi_modal_data
        return (prompt_text, prompt_token_ids, multi_modal_data,
                mm_processor_kwargs)

    async def _extract_prompt_components_async(
        self,
@@ -273,10 +280,12 @@ class InputPreprocessor:
                lora_request=lora_request,
            )
            multi_modal_data = None
            mm_processor_kwargs = None
        elif parsed["type"] == "tokens":
            prompt_text = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
        elif parsed["type"] == "text":
            prompt_text = parsed["content"]["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
@@ -285,18 +294,21 @@ class InputPreprocessor:
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
        else:
            assert_never(parsed)

        return prompt_text, prompt_token_ids, multi_modal_data
        return (prompt_text, prompt_token_ids, multi_modal_data,
                mm_processor_kwargs)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
        mm_processor_kwargs: Dict[str, Any],
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
        encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps

        if decoder_mm_data is not None:
            raise ValueError(
@@ -314,6 +326,7 @@ class InputPreprocessor:
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            multi_modal_data=decoder_mm_data,
            mm_processor_kwargs=mm_processor_kwargs,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
            encoder_multi_modal_data=encoder_mm_data,
@@ -367,21 +380,30 @@ class InputPreprocessor:
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_comps = None, None, None
                decoder_comps = None, None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
            # Handle this carefully in case it was directly initialized by user
            mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
        else:
            encoder_comps = self._extract_prompt_components(
                prompt,
                request_id=request_id,
            )
            # If there are no decoder components, we assume the
            # mm_processor_kwargs are in the encoder prompt
            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
                -1] is not None else {}
            decoder_comps = None, None, None, None

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
        return self._build_enc_dec_llm_inputs(
            encoder_comps,
            decoder_comps,
            mm_processor_kwargs,
        )

    async def _process_encoder_decoder_prompt_async(
        self,
@@ -400,7 +422,7 @@ class InputPreprocessor:

        if (decoder_input := prompt["decoder_prompt"]) is None:
            encoder_comps = await encoder_task
            decoder_comps = None, None, None
            decoder_comps = None, None, None, None
        else:
            decoder_task = self._extract_prompt_components_async(
                decoder_input,
@@ -409,29 +431,39 @@ class InputPreprocessor:

            encoder_comps, decoder_comps = await asyncio.gather(
                encoder_task, decoder_task)
            mm_processor_kwargs = prompt["mm_processor_kwargs"]
        else:
            encoder_comps = await self._extract_prompt_components_async(
                prompt,
                request_id=request_id,
            )
            # If there are no decoder components, we assume the
            # mm_processor_kwargs are in the encoder prompt
            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
                -1] is not None else {}
            decoder_comps = None, None, None, None

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
        return self._build_enc_dec_llm_inputs(
            encoder_comps,
            decoder_comps,
            mm_processor_kwargs,
        )

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        prompt, prompt_token_ids, multi_modal_data = prompt_comps
        (prompt, prompt_token_ids, multi_modal_data,
         mm_processor_kwargs) = prompt_comps

        prompt_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=prompt,
                         multi_modal_data=multi_modal_data)
                         multi_modal_data=multi_modal_data,
                         mm_processor_kwargs=mm_processor_kwargs)

    def _process_decoder_only_prompt(
        self,
@@ -9,7 +9,8 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
                        resolve_mm_processor_kwargs)

from .data import LLMInputs

@@ -293,8 +294,14 @@ class InputRegistry:
        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
            processor, overrides=model_config.mm_processor_kwargs)
        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs"),
            processor,
        )

        return processor(InputContext(model_config), inputs,
                         **mm_processor_kwargs)
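The comment above spells out the precedence: inference-time kwargs win over init-time kwargs, and anything left unset falls back to the processor's own defaults. A standalone sketch of that resolution using the new helper; the toy processor below is purely illustrative:

from vllm.utils import resolve_mm_processor_kwargs

def toy_processor(ctx, inputs, *, num_crops=4, do_pad=True):
    ...

resolved = resolve_mm_processor_kwargs(
    init_kwargs={"num_crops": 16, "bogus": 1},  # engine-level overrides
    inference_kwargs={"num_crops": 8},          # request-level overrides
    callable=toy_processor,
)
# "bogus" is dropped (not a keyword-only arg of the processor) and the
# request-level value wins the collision, so only num_crops=8 survives.
assert resolved == {"num_crops": 8}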
@@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "audio"

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
    def _default_input_mapper(self, ctx: InputContext, data: object,
                              **mm_processor_kwargs) -> MultiModalInputs:
        raise NotImplementedError("There is no default audio input mapper")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
@@ -1,7 +1,7 @@
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)

import numpy as np
@@ -15,7 +15,7 @@ from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
                        json_map_leaves)
                        json_map_leaves, resolve_mm_processor_kwargs)

logger = init_logger(__name__)

@@ -200,6 +200,7 @@ class MultiModalPlugin(ABC):
        self,
        ctx: InputContext,
        data: MultiModalData[object],
        **mm_processor_kwargs,
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
@@ -243,7 +244,8 @@ class MultiModalPlugin(ABC):
        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
                  data: MultiModalData[object],
                  mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.
@@ -263,19 +265,26 @@ class MultiModalPlugin(ABC):
        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        # Only get processor kwargs at mapping time if we are not using the
        # input mapper; no overrides are used on the default here because they
        # should be passed to the huggingface resource at initialization time.
        if mapper is not None and mapper != self._default_input_mapper:
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                mapper, overrides=model_config.mm_processor_kwargs)
        else:
            mm_processor_kwargs = {}

        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        # In the case of the default mapper, we have to get resource
        # processor through its HuggingFace autoclass; since this goes
        # through **kwargs, we can't inspect it the same way, so we allow
        # drop mm_processor_kwargs based on signature inspection
        # if we're using the default mapper.
        #
        # This should be safe in general due to the sanitation, since the
        # transformers resource should filter unused kwargs anyway.
        uses_default_mapper = mapper == self._default_input_mapper
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            mm_processor_kwargs,
            callable=mapper,
            allow_var_kwargs=uses_default_mapper,
        )
        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
@@ -1,4 +1,5 @@
from functools import lru_cache
from typing import Any, Dict, Optional

import torch
from PIL import Image
@@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
                               else model_config.mm_processor_kwargs)
    # We don't explicitly check kwarg overrides to the HF class
    # since the automodel just takes kwargs, so we can't inspect it
    def _get_hf_image_processor(
        self,
        model_config: ModelConfig,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
@@ -37,6 +40,7 @@ class ImagePlugin(MultiModalPlugin):
        self,
        ctx: InputContext,
        data: MultiModalData[object],
        **mm_processor_kwargs,
    ) -> MultiModalInputs:
        model_config = ctx.model_config

@@ -46,12 +50,20 @@ class ImagePlugin(MultiModalPlugin):

        # PIL image
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            image_processor = self._get_hf_image_processor(
                model_config,
                mm_processor_kwargs,
            )

            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                # NOTE: It may make sense to forward the mm_processor_kwargs
                # here too. For now, to keep it simple, we only allow it to be
                # used for the initialization call though, just in case the
                # signatures of the preprocessor initializer don't match
                # preprocess()
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
@@ -1,6 +1,6 @@
import functools
from collections import UserDict
from typing import Dict, Mapping, Optional, Sequence
from typing import Any, Dict, Mapping, Optional, Sequence

from vllm.config import ModelConfig
from vllm.logger import init_logger
@@ -96,8 +96,12 @@ class MultiModalRegistry:
        """
        return self.register_input_mapper("image", mapper)

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalDataDict) -> MultiModalInputs:
    def map_input(
        self,
        model_config: ModelConfig,
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model.

@@ -123,7 +127,8 @@ class MultiModalRegistry:
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value)
            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import List, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np

@@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin):
    def get_data_key(self) -> str:
        return "video"

    def _get_hf_video_processor(self, model_config: ModelConfig):
        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
                               else model_config.mm_processor_kwargs)
    # We don't explicitly check kwarg overrides to the HF class
    # since the automodel just takes kwargs, so we can't inspect it
    def _get_hf_video_processor(
        self,
        model_config: ModelConfig,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}
        return cached_get_video_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
@@ -50,16 +52,24 @@ class VideoPlugin(ImagePlugin):
        self,
        ctx: InputContext,
        data: MultiModalData[object],
        **mm_processor_kwargs,
    ) -> MultiModalInputs:
        model_config = ctx.model_config

        # single video input as np.ndarray
        if isinstance(data, np.ndarray):
            video_processor = self._get_hf_video_processor(model_config)
            video_processor = self._get_hf_video_processor(
                model_config,
                mm_processor_kwargs,
            )
            if video_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                # NOTE: Similar to image; it may be a good idea to filter and
                # pass mm_processor_kwargs here too, but for now we don't to
                # avoid extra complexity if the initializer and preprocess
                # signatures of the processor don't align
                batch_data = video_processor(data, return_tensors="pt").data
            except Exception:
                logger.error("Failed to process image (%s)", data)
@@ -481,6 +481,10 @@ class Sequence:
                           EncoderDecoderLLMInputs,
                           inputs).get("encoder_multi_modal_data")) or {}

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        return self.inputs.get("mm_processor_kwargs") or {}

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0
@@ -710,6 +714,14 @@ class SequenceGroup:
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        # As with multi-modal data, all sequences in the group should have the
        # same processor kwargs (i.e., mm_processor_kwargs are optionally
        # provided per request; note that they are independent of whether the
        # model is decoder-only or an encoder-decoder).
        return self.seqs[0].mm_processor_kwargs

    @property
    def lora_int_id(self) -> int:
        return self.lora_request.lora_int_id if self.lora_request else 0
@@ -949,6 +961,7 @@ class SequenceGroupMetadata(
        used in prefix caching.
        state: Internal state tied to this sequence group.
        multi_modal_data: Multi modal data.
        mm_processor_kwargs: Multimodal input processor / mapper overrides.
        encoder_seq_data: Optional sequence data for encoder prompt
                          (SequenceGroup.encoder_seq). Should be None
                          unless you are working with an encoder/decoder
@@ -975,6 +988,7 @@ class SequenceGroupMetadata(
    # "MultiModalDataDict" types. We have to use Any due to msgspec
    # doesn't allow to have union of 2 different dicts.
    multi_modal_data: Optional[Any] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    encoder_seq_data: Optional[SequenceData] = None
    cross_block_table: Optional[List[int]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None
@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
    return await task(*args, **kwargs)


def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
    disallows kwargs names that can also be positional arguments.
    """
    params = inspect.signature(callable).parameters
    if kw_name in params:
        return True
    if not params:
        return False

    return any(param.kind == inspect.Parameter.VAR_KEYWORD
               for param in params.values())
    param_val = params.get(kw_name)

    # Types where it may be valid, i.e., explicitly defined & nonvariadic
    passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
                             inspect.Parameter.POSITIONAL_OR_KEYWORD,
                             inspect.Parameter.KEYWORD_ONLY))

    if param_val:
        is_sig_param = param_val.kind in passable_kw_types
        # We want kwargs only, but this is passable as a positional arg
        if (requires_kw_only and is_sig_param
                and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
            return False
        if ((requires_kw_only
             and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
                or (not requires_kw_only and is_sig_param)):
            return True

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
                and last_param.name != kw_name)
    return False


def resolve_mm_processor_kwargs(
    init_kwargs: Optional[Dict[str, Any]],
    inference_kwargs: Optional[Dict[str, Any]],
    callable: Callable[..., object],
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
    those that are not explicit keywords to the given callable (if one is
    given; otherwise no filtering is done), then merges the kwarg dicts,
    giving priority to inference_kwargs if there are any collisions.

    In the case that no kwarg overrides are provided, returns an empty
    dict so that it can still be kwarg expanded into the callable later on.

    If allow_var_kwargs=True, allows for things that can be expanded into
    kwargs as long as they aren't a naming collision for var_kwargs or
    potential positional arguments.
    """
    # Filter inference time multimodal processor kwargs provided
    runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
        callable,
        overrides=inference_kwargs,
        allow_var_kwargs=allow_var_kwargs)

    # Filter init time multimodal processor kwargs provided
    init_mm_kwargs = get_allowed_kwarg_only_overrides(
        callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)

    # Merge the final processor kwargs, prioritizing inference
    # time values over the initialization time values.
    mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
    return mm_processor_kwargs


def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Dict[str, Any]],
    allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
                  If None is provided, all override names are allowed.
        overrides: Potential overrides to be used when invoking the callable.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
    if not overrides:
        return {}

    allowed_override_names = [
        name for name, param in inspect.signature(callable).parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    # Drop any mm_processor_kwargs provided by the user that are
    # not kwarg names accepted by the provided input processor.
    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless they can fit the var_kwargs param
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if kwarg_name in allowed_override_names
        if supports_kw(callable,
                       kwarg_name,
                       requires_kw_only=True,
                       allow_var_kwargs=allow_var_kwargs)
    }

    # If anything is dropped, log a warning
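A few illustrative calls to the extended supports_kw, matching rows of the parametrized test added earlier in this diff:

from vllm.utils import supports_kw

# A positional-or-keyword parameter counts unless keyword-only is required.
assert supports_kw(lambda foo: None, "foo")
assert not supports_kw(lambda foo: None, "foo", requires_kw_only=True)

# Keyword-only parameters are always accepted.
assert supports_kw(lambda *, foo: None, "foo")

# Names of variadic parameters are never passable keywords, but **kwargs
# can absorb other names when allow_var_kwargs is left enabled.
assert not supports_kw(lambda **kwargs: None, "kwargs")
assert supports_kw(lambda foo, **kwargs: None, "something_else")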
|
@ -148,8 +148,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
)
|
||||
|
||||
def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
|
||||
computed_len: int):
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
computed_len: int,
|
||||
mm_processor_kwargs: Dict[str, Any]):
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)
|
||||
|
||||
# special processing for mrope position deltas.
|
||||
mrope_positions = None
|
||||
@ -210,7 +211,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
mrope_positions = None
|
||||
if (mm_data := seq_group_metadata.multi_modal_data):
|
||||
mm_kwargs, mrope_positions = self._compute_multi_modal_input(
|
||||
seq_data, mm_data, computed_len)
|
||||
seq_data, mm_data, computed_len,
|
||||
seq_group_metadata.mm_processor_kwargs)
|
||||
multi_modal_inputs_list.append(mm_kwargs)
|
||||
|
||||
# Token position ids
|
||||
|
@@ -640,7 +640,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
        if not mm_data:
            return

        mm_kwargs = self.multi_modal_input_mapper(mm_data)
        mm_kwargs = self.multi_modal_input_mapper(
            mm_data,
            mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
        inter_data.multi_modal_inputs = mm_kwargs

        # special processing for mrope position deltas.
@@ -153,7 +153,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                mm_kwargs = self.multi_modal_input_mapper(
                    mm_data,
                    mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
                )
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
@@ -172,7 +172,11 @@ class OpenVINOModelRunner:

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                mm_kwargs = self.multi_modal_input_mapper(
                    mm_data,
                    mm_processor_kwargs=seq_group_metadata.
                    mm_processor_kwargs,
                )
                multi_modal_inputs_list.append(mm_kwargs)

            block_table = seq_group_metadata.block_tables[seq_id]