[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

Parent: 8c746226c9
Commit: a3691b6b5e
@@ -105,6 +105,7 @@ def run_phi3v(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         max_num_seqs=2,
+        # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
     )
     stop_token_ids = None
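
The example above sets num_crops when the engine is constructed; per the new note, the same override can also ride along with an individual request. A minimal sketch of both paths, assuming the Phi-3-vision model used in this example file and a PIL image already loaded as `image`:

    from vllm import LLM

    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        mm_processor_kwargs={"num_crops": 16},  # init-time default
    )
    # Inference-time override carried on the prompt dict itself
    outputs = llm.generate({
        "prompt": "<|user|>\n<|image_1|>\nWhat is shown?<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": image},
        "mm_processor_kwargs": {"num_crops": 4},
    })
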
@@ -74,11 +74,11 @@ def mm_model_cls():
 # lambda whose signature matches max token calcs extra & mapper + extra kwargs
 get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
 custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
-    "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
+    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
 }


-### Test for default processor logic & mm_processor_kwargs wrapping
+### Tests for default processor logic & mm_processor_kwargs wrapping
 def test_default_processor_is_a_noop():
     """Ensure that by default, there is no processor override."""
     dummy_registry = InputRegistry()
@@ -89,23 +89,46 @@ def test_default_processor_is_a_noop():
     assert proc_inputs is proc_outputs


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_processor_default_kwargs(use_processor_mock, num_crops):
-    """Ensure input processors can use processor kwargs."""
-    dummy_registry = InputRegistry()
+def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
+    """Get the init / inference kwargs and expected num_crops for this test."""
     # If we have a value for num_crops, pass the override value and make
     # sure we get that value as a return-value from our mock processor,
     # otherwise fall back to the default value
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
+    init_kwargs = None if init_num_crops is None else {
+        "num_crops": init_num_crops
     }
-    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
-    ctx = build_model_context(DUMMY_MODEL_ID,
-                              mm_processor_kwargs=mm_processor_kwargs)
-    processor = dummy_registry.create_input_processor(ctx.model_config)
+    inference_kwargs = None if inference_num_crops is None else {
+        "num_crops": inference_num_crops
+    }
+    if inference_num_crops is not None:
+        expected_seq_count = inference_num_crops
+    elif init_num_crops is not None:
+        expected_seq_count = init_num_crops
+    else:
+        expected_seq_count = DEFAULT_NUM_CROPS
+    return init_kwargs, inference_kwargs, expected_seq_count

-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
-    assert num_crops_val == expected_num_crops
+
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_input_processor_kwargs(use_processor_mock, init_num_crops,
+                                inference_num_crops):
+    """Ensure input processors can use processor kwargs."""
+    dummy_registry = InputRegistry()
+
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
+
+    ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
+    processor = dummy_registry.create_input_processor(ctx.model_config)
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=inference_kwargs))
+    assert num_crops_val == expected_seq_count


 @pytest.mark.parametrize(
@@ -124,11 +147,16 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
                                             mm_processor_kwargs):
     """Ensure that input processors filter out invalid mm_processor_kwargs"""
     dummy_registry = InputRegistry()
+    # Should filter out the init time kwargs
     ctx = build_model_context(DUMMY_MODEL_ID,
                               mm_processor_kwargs=mm_processor_kwargs)

     processor = dummy_registry.create_input_processor(ctx.model_config)
-    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
+    # Should filter out the inference time kwargs
+    num_crops_val = processor(
+        LLMInputs(prompt_token_ids=[],
+                  prompt="",
+                  mm_processor_kwargs=mm_processor_kwargs))
     assert num_crops_val == DEFAULT_NUM_CROPS

@@ -271,32 +299,34 @@ def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
     assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


-@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
-def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
+@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
+    (None, None),
+    (NUM_CROPS_OVERRIDE, None),
+    (DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE),
+])
+def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
+                                       inference_num_crops):
     """Ensure custom mappers can use processor kwargs."""
-    mm_processor_kwargs = None if num_crops is None else {
-        "num_crops": num_crops
-    }
-    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
+    init_kwargs, inference_kwargs, expected_seq_count = _get_num_crops_info(
+        init_num_crops, inference_num_crops)
+
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
-                              mm_processor_kwargs=mm_processor_kwargs,
+                              mm_processor_kwargs=init_kwargs,
                               limit_mm_per_prompt={"image": 1})

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echoes
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echoes
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs,
+                                          inference_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1

@@ -316,6 +346,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
 def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
                                                 mm_processor_kwargs):
     """Ensure that custom mappers filter out invalid mm_processor_kwargs"""
+    # Should filter out the init time kwargs
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
@@ -323,17 +354,16 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,

     mm_registry = MultiModalRegistry()
     mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-    # Patch the image registry for phi3v with our lambda that is compatible
-    # with overrides, then ensure that calling the method correctly echoes
-    # our num_crops value back from the mm_processor_kwargs.
     image = image_assets[0].pil_image
     mm_inputs = {"image": image}

-    with patch.object(
-            mm_registry._get_plugin("image"),
-            "_default_input_mapper",
-            {mm_model_cls(): custom_mapper},
-    ):
-        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
+    # Patch the image registry for phi3v with our lambda that is compatible
+    # with overrides, then ensure that calling the method correctly echoes
+    # our num_crops value back from the mm_processor_kwargs.
+    mm_registry._get_plugin("image").register_input_mapper(custom_mapper)(
+        mm_model_cls())
+    # Should filter out the inference time kwargs
+    mapped_inputs = mm_registry.map_input(
+        ctx.model_config, mm_inputs, mm_processor_kwargs=mm_processor_kwargs)

     assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1

@@ -2,6 +2,7 @@ from typing import List

 import pytest

+from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [
@@ -51,3 +52,28 @@ def test_parse_single_batch_token_consistent(token_input: List[int]):
 def test_parse_single_batch_string_slice(inputs_slice: slice):
     assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
         == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
+
+
+# yapf: disable
+@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
+    (None, [{}, {}]),
+    ({}, [{}, {}]),
+    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
+    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
+])
+# yapf: enable
+def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
+    """Test mm_processor_kwargs init for zipping enc/dec prompts."""
+    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
+    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
+    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
+                                         mm_processor_kwargs)
+    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
+    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
+                                            expected_mm_kwargs,
+                                            zipped_prompts):
+        assert isinstance(zipped, dict)
+        assert len(zipped.keys()) == 3
+        assert zipped['encoder_prompt'] == enc
+        assert zipped['decoder_prompt'] == dec
+        assert zipped['mm_processor_kwargs'] == exp_kwargs
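
For reference, a short sketch of the behavior this test pins down: a single dict is broadcast to every encoder/decoder pair, while an iterable is zipped elementwise (values reuse the test's fixtures):

    from vllm.inputs import zip_enc_dec_prompts

    zipped = zip_enc_dec_prompts(['An encoder prompt', 'Another encoder prompt'],
                                 ['A decoder prompt', 'Another decoder prompt'],
                                 {"foo": 100})
    # Both ExplicitEncoderDecoderPrompt dicts carry the same kwargs
    assert all(p['mm_processor_kwargs'] == {"foo": 100} for p in zipped)
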
@@ -7,7 +7,7 @@ from typing import AsyncIterator, Tuple
 import pytest

 from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
-                        get_open_port, merge_async_iterators)
+                        get_open_port, merge_async_iterators, supports_kw)

 from .utils import error_on_warning

@@ -236,3 +236,33 @@ def test_no_model_tag(parser_with_config):
     with pytest.raises(ValueError):
         parser_with_config.parse_args(
             ['serve', '--config', './data/test_config.yaml'])
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
+    [
+        # Tests for positional argument support
+        (lambda foo: None, "foo", True, True, False),
+        (lambda foo: None, "foo", False, True, True),
+        # Tests for positional or keyword / keyword only
+        (lambda foo=100: None, "foo", True, True, False),
+        (lambda *, foo: None, "foo", False, True, True),
+        # Tests to make sure the names of variadic params are NOT supported
+        (lambda *args: None, "args", False, True, False),
+        (lambda **kwargs: None, "kwargs", False, True, False),
+        # Tests for if we allow var kwargs to add support
+        (lambda foo: None, "something_else", False, True, False),
+        (lambda foo, **kwargs: None, "something_else", False, True, True),
+        (lambda foo, **kwargs: None, "kwargs", True, True, False),
+        (lambda foo, **kwargs: None, "foo", True, True, False),
+    ])
+# yapf: disable
+def test_supports_kw(callable,kw_name,requires_kw_only,
+                     allow_var_kwargs,is_supported):
+    assert supports_kw(
+        callable=callable,
+        kw_name=kw_name,
+        requires_kw_only=requires_kw_only,
+        allow_var_kwargs=allow_var_kwargs
+    ) == is_supported
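
Reading the table above, supports_kw answers "can kw_name be passed as a keyword to this callable?". A few cases spelled out (expected values taken directly from the parametrized rows; note the defaults requires_kw_only=False, allow_var_kwargs=True):

    from vllm.utils import supports_kw

    # Positional-or-keyword params count unless keyword-only is required
    assert supports_kw(lambda foo: None, "foo", requires_kw_only=False)
    assert not supports_kw(lambda foo: None, "foo", requires_kw_only=True)
    # Names of variadic params themselves are never supported
    assert not supports_kw(lambda **kwargs: None, "kwargs")
    # ...but **kwargs can absorb other names when var kwargs are allowed
    assert supports_kw(lambda foo, **kwargs: None, "something_else",
                       allow_var_kwargs=True)
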
@@ -1309,6 +1309,7 @@ class Scheduler:
                 # `multi_modal_data` will be None.
                 multi_modal_data=seq_group.multi_modal_data
                 if scheduler_outputs.num_prefill_groups > 0 else None,
+                mm_processor_kwargs=seq_group.mm_processor_kwargs,
                 prompt_adapter_request=seq_group.prompt_adapter_request,
             )
         else:
@@ -811,6 +811,13 @@ class LLMEngine:
         )
         processed_inputs = self.input_processor(preprocessed_inputs)

+        # This is a bit of a hack - copy the mm_processor_kwargs that were
+        # used in the input processor to the processed output, since these
+        # kwargs are presumed to be immutable and the values should be aligned
+        # between the input processor (here) and the input mapper.
+        processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
+            "mm_processor_kwargs")
+
         self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
@@ -472,6 +472,7 @@ class LLM:
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tools: Optional[List[Dict[str, Any]]] = None,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.
@@ -501,6 +502,8 @@ class LLM:
             continue_final_message: If True, continues the final message in
                 the conversation instead of starting a new one. Cannot be
                 `True` if `add_generation_prompt` is also `True`.
+            mm_processor_kwargs: Multimodal processor kwarg overrides for this
+                chat request. Only used for offline requests.

         Returns:
             A list of ``RequestOutput`` objects containing the generated
|
|||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
model_config = self.llm_engine.get_model_config()
|
model_config = self.llm_engine.get_model_config()
|
||||||
|
|
||||||
|
# NOTE: _parse_chat_message_content_parts() currently doesn't
|
||||||
|
# handle mm_processor_kwargs, since there is no implementation in
|
||||||
|
# the chat message parsing for it.
|
||||||
conversation, mm_data = parse_chat_messages(
|
conversation, mm_data = parse_chat_messages(
|
||||||
msgs, model_config, tokenizer)
|
msgs, model_config, tokenizer)
|
||||||
|
|
||||||
@@ -554,6 +560,9 @@ class LLM:
             if mm_data is not None:
                 prompt["multi_modal_data"] = mm_data

+            if mm_processor_kwargs is not None:
+                prompt["mm_processor_kwargs"] = mm_processor_kwargs
+
             prompts.append(prompt)

         return self.generate(
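
Putting the new parameter together with the prompt plumbing above, offline chat requests can now override processor kwargs per call. A hedged sketch (messages and sampling_params assumed to be in scope):

    outputs = llm.chat(
        messages,
        sampling_params,
        mm_processor_kwargs={"num_crops": 16},
    )
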
@@ -1,5 +1,5 @@
-from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
+                    Optional, Tuple, Union)

 from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -19,6 +19,14 @@ class TextPrompt(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 class TokensPrompt(TypedDict):
     """Schema for a tokenized prompt."""
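
With the new TypedDict field above, a prompt dict can carry the overrides directly; a small sketch assuming a PIL image is already loaded as `image`:

    from vllm.inputs import TextPrompt

    prompt = TextPrompt(
        prompt="Describe the image.",
        multi_modal_data={"image": image},
        mm_processor_kwargs={"num_crops": 4},
    )
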
@@ -32,6 +40,14 @@ class TokensPrompt(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """
@@ -74,7 +90,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
     according to any of the :class:`SingletonPrompt` schemas,
     and are not required to have the same schema.

-    Only the encoder prompt may have multi-modal data.
+    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
+    should be at the top-level, and should not be set in the encoder/decoder
+    prompts, since they are agnostic to the encoder/decoder.

     Note that an :class:`ExplicitEncoderDecoderPrompt` may not
     be used as an input to a decoder-only model,
@@ -87,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):

     decoder_prompt: Optional[_T2_co]

+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
+

 PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
 """
@@ -121,6 +141,14 @@ class LLMInputs(TypedDict):
     if the model supports it.
     """

+    mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
+    """
+    Optional multi-modal processor kwargs to be forwarded to the
+    multimodal input mapper & processor. Note that if multiple modalities
+    have registered mappers etc for the model being considered, we attempt
+    to pass the mm_processor_kwargs to each of them.
+    """
+

 class EncoderDecoderLLMInputs(LLMInputs):
     """
@@ -152,22 +180,43 @@ _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
 def build_explicit_enc_dec_prompt(
     encoder_prompt: _T1,
     decoder_prompt: Optional[_T2],
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
-    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
-                                        decoder_prompt=decoder_prompt)
+    if mm_processor_kwargs is None:
+        mm_processor_kwargs = {}
+    return ExplicitEncoderDecoderPrompt(
+        encoder_prompt=encoder_prompt,
+        decoder_prompt=decoder_prompt,
+        mm_processor_kwargs=mm_processor_kwargs)


 def zip_enc_dec_prompts(
     enc_prompts: Iterable[_T1],
     dec_prompts: Iterable[Optional[_T2]],
+    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
+                                        Dict[str, Any]]] = None,
 ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
     """
     Zip encoder and decoder prompts together into a list of
-    :class:`ExplicitEncoderDecoderPrompt` instances.
+    :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs
+    may also be provided; if a dict is passed, the same dictionary will be
+    used for every encoder/decoder prompt. If an iterable is provided, it will
+    be zipped with the encoder/decoder prompts.
     """
+    if mm_processor_kwargs is None:
+        mm_processor_kwargs = {}
+    if isinstance(mm_processor_kwargs, Dict):
+        return [
+            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
+                                          mm_processor_kwargs)
+            for (encoder_prompt,
+                 decoder_prompt) in zip(enc_prompts, dec_prompts)
+        ]
     return [
-        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
-        for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
+        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
+                                      mm_proc_kwargs)
+        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
+             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
     ]


@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from typing_extensions import assert_never

@@ -20,9 +20,11 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 PromptComponents = Tuple[Optional[str], List[int],
-                         Optional["MultiModalDataDict"]]
+                         Optional["MultiModalDataDict"], Optional[Dict[str,
+                                                                       Any]]]
 DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
-                                Optional["MultiModalDataDict"]]
+                                Optional["MultiModalDataDict"],
+                                Optional[Dict[str, Any]]]


 class InputPreprocessor:
@@ -227,6 +229,7 @@ class InputPreprocessor:
         * prompt
         * prompt_token_ids
         * multi_modal_data
+        * mm_processor_kwargs (request-level input processor/mapper overrides)
         '''

         parsed = parse_singleton_prompt(prompt)
@@ -239,10 +242,12 @@ class InputPreprocessor:
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            mm_processor_kwargs = None
         elif parsed["type"] == "tokens":
             prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         elif parsed["type"] == "text":
             prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = self._tokenize_prompt(
@@ -251,10 +256,12 @@ class InputPreprocessor:
                 lora_request=lora_request,
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         else:
             assert_never(parsed)

-        return prompt_text, prompt_token_ids, multi_modal_data
+        return (prompt_text, prompt_token_ids, multi_modal_data,
+                mm_processor_kwargs)

     async def _extract_prompt_components_async(
         self,
@@ -273,10 +280,12 @@ class InputPreprocessor:
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            mm_processor_kwargs = None
         elif parsed["type"] == "tokens":
             prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         elif parsed["type"] == "text":
             prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(
@@ -285,18 +294,21 @@ class InputPreprocessor:
                 lora_request=lora_request,
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
+            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
         else:
             assert_never(parsed)

-        return prompt_text, prompt_token_ids, multi_modal_data
+        return (prompt_text, prompt_token_ids, multi_modal_data,
+                mm_processor_kwargs)

     def _build_enc_dec_llm_inputs(
         self,
         encoder_comps: PromptComponents,
         decoder_comps: DecoderPromptComponents,
+        mm_processor_kwargs: Dict[str, Any],
     ) -> EncoderDecoderLLMInputs:
-        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
-        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+        encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
+        decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps

         if decoder_mm_data is not None:
             raise ValueError(
|
|||||||
prompt_token_ids=decoder_prompt_ids,
|
prompt_token_ids=decoder_prompt_ids,
|
||||||
prompt=decoder_prompt,
|
prompt=decoder_prompt,
|
||||||
multi_modal_data=decoder_mm_data,
|
multi_modal_data=decoder_mm_data,
|
||||||
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
encoder_prompt_token_ids=encoder_prompt_ids,
|
encoder_prompt_token_ids=encoder_prompt_ids,
|
||||||
encoder_prompt=encoder_prompt,
|
encoder_prompt=encoder_prompt,
|
||||||
encoder_multi_modal_data=encoder_mm_data,
|
encoder_multi_modal_data=encoder_mm_data,
|
||||||
@ -367,21 +380,30 @@ class InputPreprocessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None, None
|
||||||
else:
|
else:
|
||||||
decoder_comps = self._extract_prompt_components(
|
decoder_comps = self._extract_prompt_components(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
# Handle this carefully in case it was directly initialized by user
|
||||||
|
mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
|
||||||
else:
|
else:
|
||||||
encoder_comps = self._extract_prompt_components(
|
encoder_comps = self._extract_prompt_components(
|
||||||
prompt,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
# If there are no decoder components, we assume the
|
||||||
|
# mm_processor_kwargs are in the encoder prompt
|
||||||
|
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
|
||||||
|
-1] is not None else {}
|
||||||
|
decoder_comps = None, None, None, None
|
||||||
|
|
||||||
decoder_comps = None, None, None
|
return self._build_enc_dec_llm_inputs(
|
||||||
|
encoder_comps,
|
||||||
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
decoder_comps,
|
||||||
|
mm_processor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
async def _process_encoder_decoder_prompt_async(
|
async def _process_encoder_decoder_prompt_async(
|
||||||
self,
|
self,
|
||||||
@ -400,7 +422,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||||
encoder_comps = await encoder_task
|
encoder_comps = await encoder_task
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None, None
|
||||||
else:
|
else:
|
||||||
decoder_task = self._extract_prompt_components_async(
|
decoder_task = self._extract_prompt_components_async(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
@ -409,29 +431,39 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
encoder_comps, decoder_comps = await asyncio.gather(
|
encoder_comps, decoder_comps = await asyncio.gather(
|
||||||
encoder_task, decoder_task)
|
encoder_task, decoder_task)
|
||||||
|
mm_processor_kwargs = prompt["mm_processor_kwargs"]
|
||||||
else:
|
else:
|
||||||
encoder_comps = await self._extract_prompt_components_async(
|
encoder_comps = await self._extract_prompt_components_async(
|
||||||
prompt,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
# If there are no decoder components, we assume the
|
||||||
|
# mm_processor_kwargs are in the encoder prompt
|
||||||
|
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
|
||||||
|
-1] is not None else {}
|
||||||
|
decoder_comps = None, None, None, None
|
||||||
|
|
||||||
decoder_comps = None, None, None
|
return self._build_enc_dec_llm_inputs(
|
||||||
|
encoder_comps,
|
||||||
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
decoder_comps,
|
||||||
|
mm_processor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_decoder_only_llm_inputs(
|
def _build_decoder_only_llm_inputs(
|
||||||
self,
|
self,
|
||||||
prompt_comps: PromptComponents,
|
prompt_comps: PromptComponents,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||||
) -> LLMInputs:
|
) -> LLMInputs:
|
||||||
prompt, prompt_token_ids, multi_modal_data = prompt_comps
|
(prompt, prompt_token_ids, multi_modal_data,
|
||||||
|
mm_processor_kwargs) = prompt_comps
|
||||||
|
|
||||||
prompt_token_ids = self._apply_prompt_adapter(
|
prompt_token_ids = self._apply_prompt_adapter(
|
||||||
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
||||||
|
|
||||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
multi_modal_data=multi_modal_data)
|
multi_modal_data=multi_modal_data,
|
||||||
|
mm_processor_kwargs=mm_processor_kwargs)
|
||||||
|
|
||||||
def _process_decoder_only_prompt(
|
def _process_decoder_only_prompt(
|
||||||
self,
|
self,
|
||||||
|
@@ -9,7 +9,8 @@ from transformers import PretrainedConfig
 from typing_extensions import TypeVar

 from vllm.logger import init_logger
-from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once
+from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
+                        resolve_mm_processor_kwargs)

 from .data import LLMInputs

@@ -293,8 +294,14 @@ class InputRegistry:
         model_cls, _ = get_model_architecture(model_config)
         processor = self._get_model_input_processor(model_cls)

-        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
-            processor, overrides=model_config.mm_processor_kwargs)
+        # Handle multimodal processor kwargs with priority:
+        #     Inference kwargs -> Init kwargs -> {}
+        # If it's empty, it'll fall back to the default kwarg values
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            model_config.mm_processor_kwargs,
+            inputs.get("mm_processor_kwargs"),
+            processor,
+        )

         return processor(InputContext(model_config), inputs,
                          **mm_processor_kwargs)
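
The comment above captures the heart of the feature: per-request (inference) values win over constructor (init) values, which win over the processor's own defaults. A toy illustration with a hypothetical processor whose only keyword-only parameter is num_crops:

    from vllm.utils import resolve_mm_processor_kwargs

    def processor(ctx, inputs, *, num_crops=4):
        ...

    kwargs = resolve_mm_processor_kwargs(
        {"num_crops": 16},  # init-time (model_config.mm_processor_kwargs)
        {"num_crops": 8},   # inference-time (inputs["mm_processor_kwargs"])
        processor,
    )
    assert kwargs == {"num_crops": 8}  # inference wins; absent both -> {}
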
@@ -8,8 +8,8 @@ class AudioPlugin(MultiModalPlugin):
     def get_data_key(self) -> str:
         return "audio"

-    def _default_input_mapper(self, ctx: InputContext,
-                              data: object) -> MultiModalInputs:
+    def _default_input_mapper(self, ctx: InputContext, data: object,
+                              **mm_processor_kwargs) -> MultiModalInputs:
         raise NotImplementedError("There is no default audio input mapper")

     def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
@@ -1,7 +1,7 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type,
                     TypedDict, TypeVar, Union, cast, final)

 import numpy as np
@@ -15,7 +15,7 @@ from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
-                        json_map_leaves)
+                        json_map_leaves, resolve_mm_processor_kwargs)

 logger = init_logger(__name__)

@@ -200,6 +200,7 @@ class MultiModalPlugin(ABC):
         self,
         ctx: InputContext,
         data: MultiModalData[object],
+        **mm_processor_kwargs,
     ) -> MultiModalInputs:
         """
         Return a dictionary to be passed as keyword arguments to
@@ -243,7 +244,8 @@ class MultiModalPlugin(ABC):
         return wrapper

     def map_input(self, model_config: ModelConfig,
-                  data: MultiModalData[object]) -> MultiModalInputs:
+                  data: MultiModalData[object],
+                  mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs:
         """
         Transform the data into a dictionary of model inputs using the
         input mapper registered for that model.
@@ -263,19 +265,26 @@ class MultiModalPlugin(ABC):
         model_cls, _ = get_model_architecture(model_config)

         mapper = self._input_mappers.get(model_cls)
-        # Only get processor kwargs at mapping time if we are not using the
-        # input mapper; no overrides are used on the default here because they
-        # should be passed to the huggingface resource at initialization time.
-        if mapper is not None and mapper != self._default_input_mapper:
-            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
-                mapper, overrides=model_config.mm_processor_kwargs)
-        else:
-            mm_processor_kwargs = {}
-
         if mapper is None:
             raise KeyError(f"No input mapper in {self} is registered for "
                            f"model class {model_cls.__name__}.")

+        # In the case of the default mapper, we have to get resource
+        # processor through its HuggingFace autoclass; since this goes
+        # through **kwargs, we can't inspect it the same way, so we allow
+        # drop mm_processor_kwargs based on signature inspection
+        # if we're using the default mapper.
+        #
+        # This should be safe in general due to the sanitation, since the
+        # transformers resource should filter unused kwargs anyway.
+        uses_default_mapper = mapper == self._default_input_mapper
+        mm_processor_kwargs = resolve_mm_processor_kwargs(
+            model_config.mm_processor_kwargs,
+            mm_processor_kwargs,
+            callable=mapper,
+            allow_var_kwargs=uses_default_mapper,
+        )
         return mapper(InputContext(model_config), data, **mm_processor_kwargs)

     @abstractmethod
@@ -1,4 +1,5 @@
 from functools import lru_cache
+from typing import Any, Dict, Optional

 import torch
 from PIL import Image
@@ -23,11 +24,13 @@ class ImagePlugin(MultiModalPlugin):
     def get_data_key(self) -> str:
         return "image"

-    def _get_hf_image_processor(self, model_config: ModelConfig):
-        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
-                               else model_config.mm_processor_kwargs)
-        # We don't explicitly check kwarg overrides to the HF class
-        # since the automodel just takes kwargs, so we can't inspect it
+    def _get_hf_image_processor(
+        self,
+        model_config: ModelConfig,
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        if mm_processor_kwargs is None:
+            mm_processor_kwargs = {}
         return cached_get_image_processor(
             model_config.model,
             trust_remote_code=model_config.trust_remote_code,
@@ -37,6 +40,7 @@ class ImagePlugin(MultiModalPlugin):
         self,
         ctx: InputContext,
         data: MultiModalData[object],
+        **mm_processor_kwargs,
     ) -> MultiModalInputs:
         model_config = ctx.model_config

||||||
@ -46,12 +50,20 @@ class ImagePlugin(MultiModalPlugin):
|
|||||||
|
|
||||||
# PIL image
|
# PIL image
|
||||||
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
||||||
image_processor = self._get_hf_image_processor(model_config)
|
image_processor = self._get_hf_image_processor(
|
||||||
|
model_config,
|
||||||
|
mm_processor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if image_processor is None:
|
if image_processor is None:
|
||||||
raise RuntimeError("No HuggingFace processor is available "
|
raise RuntimeError("No HuggingFace processor is available "
|
||||||
"to process the image object")
|
"to process the image object")
|
||||||
try:
|
try:
|
||||||
|
# NOTE: It may make sense to forward the mm_processor_kwargs
|
||||||
|
# here too. For now, to keep it simple, we only allow it be
|
||||||
|
# used for the initialization call though, just in case the
|
||||||
|
# signatures of the preprocessor initializer don't match
|
||||||
|
# preprocess()
|
||||||
batch_data = image_processor \
|
batch_data = image_processor \
|
||||||
.preprocess(data, return_tensors="pt") \
|
.preprocess(data, return_tensors="pt") \
|
||||||
.data
|
.data
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import functools
|
import functools
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from typing import Dict, Mapping, Optional, Sequence
|
from typing import Any, Dict, Mapping, Optional, Sequence
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -96,8 +96,12 @@ class MultiModalRegistry:
|
|||||||
"""
|
"""
|
||||||
return self.register_input_mapper("image", mapper)
|
return self.register_input_mapper("image", mapper)
|
||||||
|
|
||||||
def map_input(self, model_config: ModelConfig,
|
def map_input(
|
||||||
data: MultiModalDataDict) -> MultiModalInputs:
|
self,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
data: MultiModalDataDict,
|
||||||
|
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> MultiModalInputs:
|
||||||
"""
|
"""
|
||||||
Apply an input mapper to the data passed to the model.
|
Apply an input mapper to the data passed to the model.
|
||||||
|
|
||||||
@ -123,7 +127,8 @@ class MultiModalRegistry:
|
|||||||
f"`--limit-mm-per-prompt`, but found {num_items} items "
|
f"`--limit-mm-per-prompt`, but found {num_items} items "
|
||||||
"in the same prompt.")
|
"in the same prompt.")
|
||||||
|
|
||||||
input_dict = plugin.map_input(model_config, data_value)
|
input_dict = plugin.map_input(model_config, data_value,
|
||||||
|
mm_processor_kwargs)
|
||||||
for input_key, input_tensor in input_dict.items():
|
for input_key, input_tensor in input_dict.items():
|
||||||
if input_key in merged_dict:
|
if input_key in merged_dict:
|
||||||
raise ValueError(f"The input mappers (keys={set(data)}) "
|
raise ValueError(f"The input mappers (keys={set(data)}) "
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -36,11 +36,13 @@ class VideoPlugin(ImagePlugin):
|
|||||||
def get_data_key(self) -> str:
|
def get_data_key(self) -> str:
|
||||||
return "video"
|
return "video"
|
||||||
|
|
||||||
def _get_hf_video_processor(self, model_config: ModelConfig):
|
def _get_hf_video_processor(
|
||||||
mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
|
self,
|
||||||
else model_config.mm_processor_kwargs)
|
model_config: ModelConfig,
|
||||||
# We don't explicitly check kwarg overrides to the HF class
|
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
# since the automodel just takes kwargs, so we can't inspect it
|
):
|
||||||
|
if mm_processor_kwargs is None:
|
||||||
|
mm_processor_kwargs = {}
|
||||||
return cached_get_video_processor(
|
return cached_get_video_processor(
|
||||||
model_config.model,
|
model_config.model,
|
||||||
trust_remote_code=model_config.trust_remote_code,
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
@ -50,16 +52,24 @@ class VideoPlugin(ImagePlugin):
|
|||||||
self,
|
self,
|
||||||
ctx: InputContext,
|
ctx: InputContext,
|
||||||
data: MultiModalData[object],
|
data: MultiModalData[object],
|
||||||
|
**mm_processor_kwargs,
|
||||||
) -> MultiModalInputs:
|
) -> MultiModalInputs:
|
||||||
model_config = ctx.model_config
|
model_config = ctx.model_config
|
||||||
|
|
||||||
# single video input as np.ndarray
|
# single video input as np.ndarray
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
video_processor = self._get_hf_video_processor(model_config)
|
video_processor = self._get_hf_video_processor(
|
||||||
|
model_config,
|
||||||
|
mm_processor_kwargs,
|
||||||
|
)
|
||||||
if video_processor is None:
|
if video_processor is None:
|
||||||
raise RuntimeError("No HuggingFace processor is available "
|
raise RuntimeError("No HuggingFace processor is available "
|
||||||
"to process the image object")
|
"to process the image object")
|
||||||
try:
|
try:
|
||||||
|
# NOTE: Similar to image; it may be a good idea to filter and
|
||||||
|
# pass mm_processor_kwargs here too, but for now we don't to
|
||||||
|
# avoid extra complexity if the initializer and preprocess
|
||||||
|
# signatures of the processor don't align
|
||||||
batch_data = video_processor(data, return_tensors="pt").data
|
batch_data = video_processor(data, return_tensors="pt").data
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Failed to process image (%s)", data)
|
logger.error("Failed to process image (%s)", data)
|
||||||
|
@ -481,6 +481,10 @@ class Sequence:
|
|||||||
EncoderDecoderLLMInputs,
|
EncoderDecoderLLMInputs,
|
||||||
inputs).get("encoder_multi_modal_data")) or {}
|
inputs).get("encoder_multi_modal_data")) or {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||||
|
return self.inputs.get("mm_processor_kwargs") or {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lora_int_id(self) -> int:
|
def lora_int_id(self) -> int:
|
||||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||||
@ -710,6 +714,14 @@ class SequenceGroup:
|
|||||||
# We use the multi-modal data of an arbitrary sequence.
|
# We use the multi-modal data of an arbitrary sequence.
|
||||||
return self.seqs[0].multi_modal_data
|
return self.seqs[0].multi_modal_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||||
|
# As with multi-modal data, all sequences in the group should have the
|
||||||
|
# same processor kwargs (i.e., mm_processor_kwargs are optionally
|
||||||
|
# provided per request; note that are independent of whether the model
|
||||||
|
# decoder-only or an encoder-decoder).
|
||||||
|
return self.seqs[0].mm_processor_kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lora_int_id(self) -> int:
|
def lora_int_id(self) -> int:
|
||||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||||
@ -949,6 +961,7 @@ class SequenceGroupMetadata(
|
|||||||
used in prefix caching.
|
used in prefix caching.
|
||||||
state: Internal state tied to this sequence group.
|
state: Internal state tied to this sequence group.
|
||||||
multi_modal_data: Multi modal data.
|
multi_modal_data: Multi modal data.
|
||||||
|
mm_processor_kwargs: Multimodal input processor / mapper overrides.
|
||||||
encoder_seq_data: Optional sequence data for encoder prompt
|
encoder_seq_data: Optional sequence data for encoder prompt
|
||||||
(SequenceGroup.encoder_seq). Should be None
|
(SequenceGroup.encoder_seq). Should be None
|
||||||
unless you are working with an encoder/decoder
|
unless you are working with an encoder/decoder
|
||||||
@ -975,6 +988,7 @@ class SequenceGroupMetadata(
|
|||||||
# "MultiModalDataDict" types. We have to use Any due to msgspec
|
# "MultiModalDataDict" types. We have to use Any due to msgspec
|
||||||
# doesn't allow to have union of 2 different dicts.
|
# doesn't allow to have union of 2 different dicts.
|
||||||
multi_modal_data: Optional[Any] = None
|
multi_modal_data: Optional[Any] = None
|
||||||
|
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||||
encoder_seq_data: Optional[SequenceData] = None
|
encoder_seq_data: Optional[SequenceData] = None
|
||||||
cross_block_table: Optional[List[int]] = None
|
cross_block_table: Optional[List[int]] = None
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||||
|
@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
|||||||
return await task(*args, **kwargs)
|
return await task(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
|
def supports_kw(
|
||||||
|
callable: Callable[..., object],
|
||||||
|
kw_name: str,
|
||||||
|
requires_kw_only: bool = False,
|
||||||
|
allow_var_kwargs: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
|
||||||
|
disallows kwargs names that can also be positional arguments.
|
||||||
|
"""
|
||||||
params = inspect.signature(callable).parameters
|
params = inspect.signature(callable).parameters
|
||||||
if kw_name in params:
|
if not params:
|
||||||
return True
|
return False
|
||||||
|
|
||||||
return any(param.kind == inspect.Parameter.VAR_KEYWORD
|
param_val = params.get(kw_name)
|
||||||
for param in params.values())
|
|
||||||
|
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
|
||||||
|
passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
inspect.Parameter.KEYWORD_ONLY))
|
||||||
|
|
||||||
|
if param_val:
|
||||||
|
is_sig_param = param_val.kind in passable_kw_types
|
||||||
|
# We want kwargs only, but this is passable as a positional arg
|
||||||
|
if (requires_kw_only and is_sig_param
|
||||||
|
and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
|
||||||
|
return False
|
||||||
|
if ((requires_kw_only
|
||||||
|
and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
|
||||||
|
or (not requires_kw_only and is_sig_param)):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If we're okay with var-kwargs, it's supported as long as
|
||||||
|
# the kw_name isn't something like *args, **kwargs
|
||||||
|
if allow_var_kwargs:
|
||||||
|
# Get the last param; type is ignored here because params is a proxy
|
||||||
|
# mapping, but it wraps an ordered dict, and they appear in order.
|
||||||
|
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
|
||||||
|
last_param = params[next(reversed(params))] # type: ignore
|
||||||
|
return (last_param.kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
and last_param.name != kw_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_mm_processor_kwargs(
|
||||||
|
init_kwargs: Optional[Dict[str, Any]],
|
||||||
|
inference_kwargs: Optional[Dict[str, Any]],
|
||||||
|
callable: Callable[..., object],
|
||||||
|
allow_var_kwargs: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
|
||||||
|
those who are not explicit keywords to the given callable (of one is
|
||||||
|
given; otherwise no filtering is done), then merges the kwarg dicts,
|
||||||
|
giving priority to inference_kwargs if there are any collisions.
|
||||||
|
|
||||||
|
In the case that no kwarg overrides are provided, returns an empty
|
||||||
|
dict so that it can still be kwarg expanded into the callable later on.
|
||||||
|
|
||||||
|
If allow_var_kwargs=True, allows for things that can be expanded into
|
||||||
|
kwargs as long as they aren't naming collision for var_kwargs or potential
|
||||||
|
positional arguments.
|
||||||
|
"""
|
||||||
|
# Filter inference time multimodal processor kwargs provided
|
||||||
|
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
|
||||||
|
callable,
|
||||||
|
overrides=inference_kwargs,
|
||||||
|
allow_var_kwargs=allow_var_kwargs)
|
||||||
|
|
||||||
|
# Filter init time multimodal processor kwargs provided
|
||||||
|
init_mm_kwargs = get_allowed_kwarg_only_overrides(
|
||||||
|
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
|
||||||
|
|
||||||
|
# Merge the final processor kwargs, prioritizing inference
|
||||||
|
# time values over the initialization time values.
|
||||||
|
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
|
||||||
|
return mm_processor_kwargs
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_kwarg_only_overrides(
|
def get_allowed_kwarg_only_overrides(
|
||||||
callable: Callable[..., object],
|
callable: Callable[..., object],
|
||||||
overrides: Optional[Dict[str, Any]],
|
overrides: Optional[Dict[str, Any]],
|
||||||
|
allow_var_kwargs: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Given a callable which has one or more keyword only params and a dict
|
Given a callable which has one or more keyword only params and a dict
|
||||||
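
One subtlety of the helpers above worth illustrating: with allow_var_kwargs=True (as map_input uses for the default mappers), an override may also flow through a **kwargs catch-all, while with the default allow_var_kwargs=False it is dropped. A sketch with a hypothetical mapper:

    from vllm.utils import resolve_mm_processor_kwargs

    def default_mapper(ctx, data, **mm_processor_kwargs):
        ...

    # Dropped: num_crops is not an explicit keyword-only parameter
    assert resolve_mm_processor_kwargs(
        {"num_crops": 16}, None, default_mapper) == {}
    # Kept: the **kwargs catch-all may absorb it when var kwargs are allowed
    assert resolve_mm_processor_kwargs(
        {"num_crops": 16}, None, default_mapper,
        allow_var_kwargs=True) == {"num_crops": 16}
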
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(

     Args:
         callable: Callable which takes 0 or more keyword only arguments.
+                  If None is provided, all override names are allowed.
         overrides: Potential overrides to be used when invoking the callable.
+        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

     Returns:
         Dictionary containing the kwargs to be leveraged which may be used
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
     if not overrides:
         return {}

-    allowed_override_names = [
-        name for name, param in inspect.signature(callable).parameters.items()
-        if param.kind == inspect.Parameter.KEYWORD_ONLY
-    ]
-
-    # Drop any mm_processor_kwargs provided by the user that are
-    # not kwarg names accepted by the provided input processor.
+    # Drop any mm_processor_kwargs provided by the user that
+    # are not kwargs, unless they can fit the var_kwargs param
     filtered_overrides = {
         kwarg_name: val
         for kwarg_name, val in overrides.items()
-        if kwarg_name in allowed_override_names
+        if supports_kw(callable,
+                       kwarg_name,
+                       requires_kw_only=True,
+                       allow_var_kwargs=allow_var_kwargs)
     }

     # If anything is dropped, log a warning
@@ -148,8 +148,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
         )

     def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
-                                   computed_len: int):
-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                                   computed_len: int,
+                                   mm_processor_kwargs: Dict[str, Any]):
+        mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs)

         # special processing for mrope position deltas.
         mrope_positions = None
@@ -210,7 +211,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
             mrope_positions = None
             if (mm_data := seq_group_metadata.multi_modal_data):
                 mm_kwargs, mrope_positions = self._compute_multi_modal_input(
-                    seq_data, mm_data, computed_len)
+                    seq_data, mm_data, computed_len,
+                    seq_group_metadata.mm_processor_kwargs)
                 multi_modal_inputs_list.append(mm_kwargs)

         # Token position ids
@@ -640,7 +640,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         if not mm_data:
             return

-        mm_kwargs = self.multi_modal_input_mapper(mm_data)
+        mm_kwargs = self.multi_modal_input_mapper(
+            mm_data,
+            mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs)
         inter_data.multi_modal_inputs = mm_kwargs

         # special processing for mrope position deltas.
@@ -153,7 +153,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data:
                 # Process multi-modal data
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                mm_kwargs = self.multi_modal_input_mapper(
+                    mm_data,
+                    mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
+                )
                 multi_modal_inputs_list.append(mm_kwargs)

         max_seq_len = max(seq_lens)
@@ -172,7 +172,11 @@ class OpenVINOModelRunner:

             mm_data = seq_group_metadata.multi_modal_data
             if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                mm_kwargs = self.multi_modal_input_mapper(
+                    mm_data,
+                    mm_processor_kwargs=seq_group_metadata.
+                    mm_processor_kwargs,
+                )
                 multi_modal_inputs_list.append(mm_kwargs)

             block_table = seq_group_metadata.block_tables[seq_id]