[VLM] Merged multimodal processor for Qwen2-Audio (#11303)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-19 14:14:17 +08:00 committed by GitHub
parent c6b0a7d3ba
commit 6142ef0ada
11 changed files with 414 additions and 358 deletions

View File

@@ -18,6 +18,10 @@ question_per_audio_count = {
     2: "What sport and what nursery rhyme are referenced?"
 }
 
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
 
 # Ultravox 0.3
 def run_ultravox(question: str, audio_count: int):
@@ -33,6 +37,8 @@ def run_ultravox(question: str, audio_count: int):
                                           add_generation_prompt=True)
 
     llm = LLM(model=model_name,
+              max_model_len=4096,
+              max_num_seqs=5,
              trust_remote_code=True,
              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None

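Note: a minimal sketch of what the Ultravox branch of this example now constructs. The checkpoint name is the one this script uses elsewhere and is assumed here; the lowered limits are exactly the ones added above.

from vllm import LLM

llm = LLM(model="fixie-ai/ultravox-v0_3",   # checkpoint assumed from the rest of this script
          max_model_len=4096,                # lowered defaults to avoid OOM on lower-end GPUs
          max_num_seqs=5,
          trust_remote_code=True,
          limit_mm_per_prompt={"audio": 1})  # one audio clip per prompt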
View File

@@ -5,6 +5,7 @@ import pytest
 import pytest_asyncio
 from transformers import AutoModel, AutoTokenizer, BatchEncoding
 
+from vllm.multimodal.audio import resample_audio
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
@@ -130,14 +131,12 @@ def run_test(
                    dtype=dtype,
                    postprocess_inputs=process,
                    auto_cls=AutoModel) as hf_model:
-        import librosa
-
        hf_outputs_per_audio = [
            hf_model.generate_greedy_logprobs_limit(
                [hf_prompt],
                max_tokens,
                num_logprobs=num_logprobs,
-                audios=[(librosa.resample(audio[0],
-                                          orig_sr=audio[1],
-                                          target_sr=16000), 16000)])
+                audios=[(resample_audio(audio[0],
+                                        orig_sr=audio[1],
+                                        target_sr=16000), 16000)])
            for _, hf_prompt, audio in prompts_and_audios

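For context, a sketch of how a test fixture's (waveform, sampling_rate) pair is now prepared for the HF reference model; shapes and rates are illustrative, and the vllm[audio] extra is assumed so that librosa is available.

import numpy as np

from vllm.multimodal.audio import resample_audio

audio = (np.zeros(48_000, dtype=np.float32), 48_000)  # (waveform, orig_sr) as produced by the fixture
hf_audio = (resample_audio(audio[0], orig_sr=audio[1], target_sr=16000), 16000)
print(hf_audio[0].shape)  # (16000,), i.e. one second at the model's 16 kHz rate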
View File

@@ -1,11 +1,11 @@
 import functools
 from collections import UserDict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
-                    Optional, Protocol, Type)
+from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
+                    Optional, Protocol, Union)
 
 from torch import nn
-from transformers import PretrainedConfig, ProcessorMixin
+from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import TypeVar, assert_never
 
 from vllm.logger import init_logger
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
 C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
+P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
 
 
 @dataclass(frozen=True)
@@ -38,24 +39,28 @@ class InputContext:
     model_config: "ModelConfig"
     """The configuration of the model."""
 
-    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
+    def get_hf_config(
+        self,
+        typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
+        /,
+    ) -> C:
         """
         Get the HuggingFace configuration
         (:class:`transformers.PretrainedConfig`) of the model,
         additionally checking its type.
 
         Raises:
-            TypeError: If the model is not of the specified type.
+            TypeError: If the configuration is not of the specified type.
         """
         hf_config = self.model_config.hf_config
-        if not isinstance(hf_config, hf_config_type):
+        if not isinstance(hf_config, typ):
             raise TypeError("Invalid type of HuggingFace config. "
-                            f"Expected type: {hf_config_type}, but "
+                            f"Expected type: {typ}, but "
                             f"found type: {type(hf_config)}")
 
         return hf_config
 
-    def get_hf_image_processor_config(self) -> Dict[str, Any]:
+    def get_hf_image_processor_config(self) -> dict[str, Any]:
         """
         Get the HuggingFace image processor configuration of the model.
         """
@@ -74,18 +79,37 @@ class InputContext:
 
         return mm_config
 
-    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
+    def get_hf_processor(
+        self,
+        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
+        /,
+        **kwargs: object,
+    ) -> P:
+        """
+        Get the HuggingFace processor
+        (:class:`transformers.ProcessorMixin`) of the model,
+        additionally checking its type.
+
+        Raises:
+            TypeError: If the processor is not of the specified type.
+        """
         base_kwargs = self.model_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}
 
         merged_kwargs = {**base_kwargs, **kwargs}
 
-        return cached_get_processor(
+        hf_processor = cached_get_processor(
             self.model_config.model,
             trust_remote_code=self.model_config.trust_remote_code,
             **merged_kwargs,
         )
+        if not isinstance(hf_processor, typ):
+            raise TypeError("Invalid type of HuggingFace processor. "
+                            f"Expected type: {typ}, but "
+                            f"found type: {type(hf_processor)}")
+
+        return hf_processor
 
 
 @dataclass(frozen=True)
@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
     tokenizer: AnyTokenizer
     """The tokenizer used to tokenize the inputs."""
 
-    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
-        base_kwargs = self.model_config.mm_processor_kwargs
-        if base_kwargs is None:
-            base_kwargs = {}
-
-        merged_kwargs = {**base_kwargs, **kwargs}
-
-        return cached_get_processor(
-            self.model_config.model,
-            tokenizer=self.tokenizer,  # Override the tokenizer with ours
-            trust_remote_code=self.model_config.trust_remote_code,
-            **merged_kwargs,
-        )
+    def get_hf_processor(
+        self,
+        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
+        /,
+        **kwargs: object,
+    ) -> P:
+        return super().get_hf_processor(
+            typ,
+            tokenizer=self.tokenizer,
+            **kwargs,
+        )
 
-    def resolve_hf_processor_call_kwargs(
+    def call_hf_processor(
         self,
         hf_processor: ProcessorMixin,
+        prompt: str,
+        processor_data: Mapping[str, object],
         inference_kwargs: Mapping[str, object],
-    ) -> Mapping[str, object]:
+    ) -> BatchFeature:
         assert callable(hf_processor)
 
         base_kwargs = self.model_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}
 
-        return resolve_mm_processor_kwargs(
+        merged_kwargs = resolve_mm_processor_kwargs(
             base_kwargs,
             inference_kwargs,
             hf_processor,
+            requires_kw_only=False,
+            allow_var_kwargs=True,
         )
 
+        try:
+            return hf_processor(
+                text=prompt,
+                **processor_data,
+                **merged_kwargs,
+                return_tensors="pt",
+            )
+        except Exception as exc:
+            data = dict(text=prompt, **processor_data)
+            msg = (f"Failed to apply {type(hf_processor).__name__} "
+                   f"on data={data} with kwargs={merged_kwargs}")
 
-N = TypeVar("N", bound=Type[nn.Module])
+            raise RuntimeError(msg) from exc
+
+
+N = TypeVar("N", bound=type[nn.Module])
 
 
 class DummyData(NamedTuple):
@@ -232,7 +272,7 @@ class InputRegistry:
 
         return wrapper
 
-    def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
+    def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
         return self._dummy_factories_by_model_type \
             .get(model_cls, self._default_dummy_data_factory)
 
@@ -257,7 +297,7 @@ class InputRegistry:
 
         return wrapper
 
-    def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
+    def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
         return self._dummy_encoder_factories_by_model_type \
             .get(model_cls, self._default_dummy_data_factory)
 
@@ -368,14 +408,14 @@ class InputRegistry:
 
         return wrapper
 
-    def _get_model_input_processor(self, model_cls: Type[nn.Module]):
+    def _get_model_input_processor(self, model_cls: type[nn.Module]):
         return self._input_processors_by_model_type \
             .get(model_cls, self._default_input_processor)
 
     def _ensure_mm_kwargs(
         self,
         inputs: SingletonInputs,
-        mm_processor_kwargs: Dict[str, Any],
+        mm_processor_kwargs: dict[str, Any],
     ):
         if inputs["type"] == "token":
             # In case the input processor for that model fails to set it

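Both accessors above now follow the same pattern: fetch the object, then isinstance-check it against a single type or a tuple of types. A self-contained sketch of that pattern (simplified, not the actual vLLM classes):

from typing import TypeVar, Union

T = TypeVar("T")

def checked(obj: object, typ: Union[type[T], tuple[type[T], ...]]) -> T:
    # Mirrors the TypeError raised by get_hf_config / get_hf_processor
    if not isinstance(obj, typ):
        raise TypeError(f"Expected type: {typ}, but found type: {type(obj)}")
    return obj

checked(3, int)           # passes
checked(3, (int, float))  # a tuple of candidate types is accepted too
try:
    checked("x", int)
except TypeError as exc:
    print(exc)

Model code can then request a concrete processor class, for example self.ctx.get_hf_processor(Qwen2AudioProcessor) later in this commit, instead of asserting the type after the fact.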
View File

@@ -133,8 +133,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         hf_processor.__is_patched__ = True  # type: ignore
 
     def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-        hf_processor = self.ctx.get_hf_processor()
-        assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
+        hf_processor = self.ctx.get_hf_processor(
+            (LlavaProcessor, PixtralProcessor))
 
         if isinstance(hf_processor, PixtralProcessor):
             self._patch_pixtral_processor(hf_processor)

View File

@@ -34,7 +34,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                         MultiModalDataDict,
                                          MultiModalDataItems, ProcessorInputs,
                                          PromptReplacement)
 from vllm.sequence import IntermediateTensors
@@ -330,20 +329,27 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
             return self.ctx.get_hf_processor(num_crops=num_crops)
         return self.ctx.get_hf_processor()
 
-    def _apply_hf_processor(
+    def _call_hf_processor(
         self,
+        hf_processor: ProcessorMixin,
         prompt: str,
-        mm_data: MultiModalDataDict,
+        processor_data: Mapping[str, object],
         mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processed_outputs = super()._apply_hf_processor(
-            prompt, mm_data, mm_processor_kwargs)
+        processed_outputs = super()._call_hf_processor(
+            hf_processor,
+            prompt=prompt,
+            processor_data=processor_data,
+            mm_processor_kwargs=mm_processor_kwargs,
+        )
+
         # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
         # which will cause OverflowError when decoding the prompt_ids.
         # Therefore, we need to do an early replacement here
         token_ids = processed_outputs['input_ids']
         token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
         processed_outputs['input_ids'] = token_ids
+
         return processed_outputs
 
     def _get_prompt_replacements(

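The early replacement above is a small tensor operation; a toy illustration with made-up token IDs (the real _IMAGE_TOKEN_ID constant lives in phi3v.py):

import torch

IMAGE_TOKEN_ID = 32000  # stand-in value for illustration only
token_ids = torch.tensor([[1, 450, -1, -1, -2, 29889]])
token_ids[token_ids < 0] = IMAGE_TOKEN_ID
print(token_ids)  # tensor([[    1,   450, 32000, 32000, 32000, 29889]])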
View File

@@ -19,45 +19,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
-from functools import cached_property, lru_cache
-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-                    Union)
+from functools import cached_property
+from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 
-import librosa
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import Qwen2AudioEncoder
+from transformers import BatchFeature, ProcessorMixin
+from transformers.models.qwen2_audio import (Qwen2AudioConfig,
+                                             Qwen2AudioEncoder,
+                                             Qwen2AudioProcessor)
+from transformers.models.whisper import WhisperFeatureExtractor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
-from vllm.logger import init_logger
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import consecutive_placeholder_ranges
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                         MultiModalDataItems, ProcessorInputs,
+                                         PromptReplacement)
+from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
 # # === Audio Inputs === #
 class Qwen2AudioInputs(TypedDict):
     input_features: torch.Tensor
-    """Shape:
-    `(num_audios, num_mel_bins, 3000)`
-    """
+    """Shape: `(num_audios, num_mel_bins, 3000)`"""
 
     feature_attention_mask: torch.Tensor
-    """Shape: `(num_audios, 3000)`
-    """
+    """Shape: `(num_audios, 3000)`"""
 
 
 # === Audio Encoder === #
@@ -74,187 +72,114 @@ class Qwen2AudioMultiModalProjector(nn.Module):
         return hidden_states
 
 
-def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
-                               mm_counts: Mapping[str, int]):
-    num_audios = mm_counts["audio"]
-    max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
-    max_llm_audio_tokens = max_tokens_per_audio * num_audios
-    if seq_len - max_llm_audio_tokens - 2 < 0:
-        raise RuntimeError(
-            f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
-            "please increase max_model_len or reduce audio limit by "
-            "--limit-mm-per-prompt.")
-
-    audio_token_index = ctx.model_config.hf_config.audio_token_index
-
-    dummy_seqdata = SequenceData.from_prompt_token_counts(
-        (audio_token_index, max_llm_audio_tokens),
-        (0, seq_len - max_llm_audio_tokens),
-    )
-    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
-    return DummyData(
-        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
-            "audio":
-            consecutive_placeholder_ranges(num_items=num_audios,
-                                           item_size=max_tokens_per_audio)
-        })
-
-
-def get_processor(
-    processor_name: str,
-    *args,
-    trust_remote_code: bool = False,
-    **kwargs,
-):
-    """Gets a processor for the given model name via HuggingFace.
-
-    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
-    """
-    # don't put this import at the top level
-    # it will call torch.cuda.device_count()
-    from transformers import AutoProcessor
-
-    try:
-        processor = AutoProcessor.from_pretrained(
-            processor_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the processor class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
-        if not trust_remote_code:
-            err_msg = (
-                "Failed to load the processor. If the processor is "
-                "a custom processor not yet available in the HuggingFace "
-                "transformers library, consider setting "
-                "`trust_remote_code=True` in LLM or using the "
-                "`--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-
-    return processor
-
-
-cached_get_processor = lru_cache(get_processor)
-
-
+# From Qwen2AudioEncoder._get_feat_extract_output_lengths
 def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
-    """
-    Computes the output length of the convolutional layers
-    and the output length of the audio encoder
-    """
-    input_lengths = (input_lengths - 1) // 2 + 1
-    output_lengths = (input_lengths - 2) // 2 + 1
-    return input_lengths, output_lengths
+    feat_lengths = (input_lengths - 1) // 2 + 1
+    output_lengths = (feat_lengths - 2) // 2 + 1
+    return feat_lengths, output_lengths
 
 
 def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
-    max_source_position = (
-        ctx.model_config.hf_config.audio_config.max_source_positions)
+    hf_config = ctx.get_hf_config(Qwen2AudioConfig)
+    max_source_position = hf_config.audio_config.max_source_positions
     output_lengths = (max_source_position - 2) // 2 + 1
     return output_lengths
 
 
-def input_processor_for_qwen2_audio(
-        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "audio" not in multi_modal_data:
-        return inputs
-
-    audios = multi_modal_data["audio"]
-    if not isinstance(audios, list):
-        audios = [audios]
-
-    if len(audios) == 0:
-        return inputs
-
-    processor = cached_get_processor(ctx.model_config.model)
-    resampled_audios = [
-        librosa.resample(audio,
-                         orig_sr=sampling_rate,
-                         target_sr=processor.feature_extractor.sampling_rate)
-        for audio, sampling_rate in audios
-    ]
-    audio_input_lengths = np.array(
-        [min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
-
-    audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
-        audio_input_lengths)
-
-    audio_token_index = ctx.model_config.hf_config.audio_token_index
-
-    input_ids = inputs['prompt_token_ids']
-
-    new_input_ids = []
-    audio_num = input_ids.count(audio_token_index)
-    assert len(audio_input_lengths) == audio_num, \
-        (f'The text input contains {audio_num} audio tokens, '
-         f'but {len(audio_input_lengths)} audios provided')
-    start = 0
-    for audio_idx in range(audio_num):
-        end = input_ids.index(audio_token_index, start)
-        new_input_ids.extend(input_ids[start:end])  # text part
-
-        new_input_ids.extend([audio_token_index] *
-                             audio_output_lengths[audio_idx])
-        start = end + 1
-    new_input_ids.extend(input_ids[start:])
-
-    return token_inputs(
-        prompt_token_ids=new_input_ids,
-        prompt=inputs.get("prompt"),
-        multi_modal_data=multi_modal_data,
-    )
-
-
-def input_mapper_for_qwen2_audio(
-    ctx: InputContext,
-    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
-) -> MultiModalKwargs:
-    """Input mapper for Qwen2-Audio."""
-    if not isinstance(multi_modal_data, list):
-        multi_modal_data = [multi_modal_data]
-
-    if len(multi_modal_data) == 0:
-        return MultiModalKwargs()
-
-    processor = cached_get_processor(ctx.model_config.model)
-    audio_feature_extractor = processor.feature_extractor
-    if audio_feature_extractor is None:
-        raise RuntimeError(
-            "No HuggingFace audio_feature_extractor is available "
-            "to process the audio object")
-
-    try:
-        resampled_audios = [
-            librosa.resample(
-                audio,
-                orig_sr=sampling_rate,
-                target_sr=processor.feature_extractor.sampling_rate)
-            for audio, sampling_rate in multi_modal_data
-        ]
-        batch_data = audio_feature_extractor(resampled_audios,
-                                             sampling_rate=16000,
-                                             return_attention_mask=True,
-                                             padding="max_length",
-                                             return_tensors="pt").data
-        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
-    except Exception:
-        logger.error("Failed to process audio (%s)", multi_modal_data)
-        raise
-
-    return MultiModalKwargs(batch_data)
-
-
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
-@MULTIMODAL_REGISTRY.register_input_mapper("audio",
-                                           input_mapper_for_qwen2_audio)
+class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_hf_processor(self) -> Qwen2AudioProcessor:
+        return self.ctx.get_hf_processor(Qwen2AudioProcessor)
+
+    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+        return self._get_hf_processor().feature_extractor  # type: ignore
+
+    def _get_processor_data(
+        self,
+        mm_items: MultiModalDataItems,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
+        mm_items.resample_audios(feature_extractor.sampling_rate)
+
+        return super()._get_processor_data(mm_items)
+
+    def _call_hf_processor(
+        self,
+        hf_processor: ProcessorMixin,
+        prompt: str,
+        processor_data: Mapping[str, object],
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processor_data = dict(processor_data)
+        audios = processor_data.pop("audios", [])
+
+        if audios:
+            processor_data["audios"] = audios
+
+            feature_extractor = self._get_feature_extractor()
+            mm_processor_kwargs = dict(
+                **mm_processor_kwargs,
+                sampling_rate=feature_extractor.sampling_rate,
+            )
+        else:
+            # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
+            pass
+
+        return super()._call_hf_processor(
+            hf_processor,
+            prompt=prompt,
+            processor_data=processor_data,
+            mm_processor_kwargs=mm_processor_kwargs,
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
+        placeholder = hf_config.audio_token_index
+
+        feature_attention_mask = hf_inputs.get("feature_attention_mask")
+        if feature_attention_mask is None:
+            audio_output_lengths = []
+        else:
+            _, audio_output_lengths = _get_feat_extract_output_lengths(
+                feature_attention_mask.sum(-1))
+
+        def get_replacement_qwen2_audio(item_idx: int):
+            return [placeholder] * audio_output_lengths[item_idx]
+
+        return [
+            PromptReplacement(
+                modality="audio",
+                target=[placeholder],
+                replacement=get_replacement_qwen2_audio,
+            )
+        ]
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
+
+        audio_count = mm_counts["audio"]
+        audio = np.zeros(audio_len)
+        data = {"audio": [audio] * audio_count}
+
+        return ProcessorInputs(
+            prompt_text="<|AUDIO|>" * audio_count,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )
+
+
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_max_qwen2_audio_audio_tokens)
+@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):
@@ -289,9 +214,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         return get_sampler()
 
-    def _validate_and_reshape_mm_tensor(self,
-                                        mm_input: Union[torch.Tensor,
-                                                        List[torch.Tensor]],
+    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
             raise ValueError(f"Incorrect type of {name}. "

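The placeholder count comes straight from the two integer divisions above. A worked example mirroring that formula (3000 mel frames correspond to a full 30 s Whisper window; the per-clip maximum matches get_max_qwen2_audio_audio_tokens when max_source_positions is 1500, which is the value I believe the released Qwen2-Audio config uses):

import torch

def feat_extract_output_lengths(input_lengths: torch.LongTensor):
    # Same arithmetic as _get_feat_extract_output_lengths in this diff
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths

mel_frames = torch.tensor([3000])  # feature_attention_mask.sum(-1) for a 30 s clip
_, audio_tokens = feat_extract_output_lengths(mel_frames)
print(audio_tokens)  # tensor([750]) -> 750 <|AUDIO|> placeholders for that clip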
View File

@@ -3,7 +3,7 @@
 
 import math
 from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import numpy as np
@@ -11,7 +11,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
-from transformers import BatchFeature
+from transformers import BatchFeature, ProcessorMixin
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 
@@ -25,11 +25,11 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                         MultiModalDataDict,
                                          MultiModalDataItems, ProcessorInputs,
                                          PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@@ -61,8 +61,8 @@ def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
 
 
 def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
-    return cached_feature_extractor(
-        ctx.get_hf_config(UltravoxConfig).audio_model_id)
+    hf_config = ctx.get_hf_config(UltravoxConfig)
+    return cached_feature_extractor(hf_config.audio_model_id)
 
 
 def get_ultravox_max_audio_tokens(ctx: InputContext):
@@ -73,72 +73,71 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
 class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
 
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
-        return self._get_hf_processor().audio_processor.feature_extractor
-
-    def _resample_audio(
-        self,
-        audio: np.ndarray,
-        sr: int,
-    ) -> Dict[str, Union[np.ndarray, int]]:
-        # resample audio to the model's sampling rate
-        feature_extractor = self._get_feature_extractor()
-        if sr != feature_extractor.sampling_rate:
-            try:
-                import librosa
-            except ImportError as exc:
-                raise ImportError(
-                    "Please install vllm[audio] for audio support.") from exc
-
-            audio = librosa.resample(audio,
-                                     orig_sr=sr,
-                                     target_sr=feature_extractor.sampling_rate)
-            sr = feature_extractor.sampling_rate
-
-        return {"audio": audio, "sampling_rate": sr}
-
-    def _apply_hf_processor(
-        self,
-        prompt: str,
-        mm_data: MultiModalDataDict,
-        mm_processor_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        if not mm_data or not mm_data.get("audio", None):
-            return super()._apply_hf_processor(prompt, mm_data,
-                                               mm_processor_kwargs)
-
-        audio_data = mm_data["audio"]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
-        # Ultravox processor doesn't support multiple inputs,
-        # therefore we need to input text and audio one by one
-        tokenizer = self._get_tokenizer()
-        audio_features, audio_token_len = [], []
-        processed_inputs = {}
-        for audio, sr in audio_data:
-            data = self._resample_audio(audio, sr)
-            processed_inputs = super()._apply_hf_processor(
-                prompt, data, mm_processor_kwargs)
-            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
-                                      skip_special_tokens=False)
-            audio_features.append(
-                processed_inputs.pop("audio_values").squeeze(0))
-            audio_token_len.append(
-                processed_inputs.pop("audio_token_len").item())
-
-        return dict(
-            **processed_inputs,
-            audio_features=audio_features,
-            audio_token_len=audio_token_len,
-        )
+        hf_processor = self._get_hf_processor()
+        return hf_processor.audio_processor.feature_extractor  # type: ignore
 
     def _get_processor_data(
         self,
-        mm_data: MultiModalDataDict,
-    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        # Ultravox uses "audio" instead of "audios" as calling keyword
-        processor_data, passthrough_data = super()._get_processor_data(mm_data)
-        if "audios" in processor_data:
-            processor_data["audio"] = processor_data.pop("audios")
-
-        return processor_data, passthrough_data
+        mm_items: MultiModalDataItems,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
+        mm_items.resample_audios(feature_extractor.sampling_rate)
+
+        return super()._get_processor_data(mm_items)
+
+    def _call_hf_processor(
+        self,
+        hf_processor: ProcessorMixin,
+        prompt: str,
+        processor_data: Mapping[str, object],
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processor_data = dict(processor_data)
+        audios = processor_data.pop("audios", [])
+
+        if not audios:
+            return super()._call_hf_processor(
+                hf_processor,
+                prompt=prompt,
+                processor_data=processor_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+
+        feature_extractor = self._get_feature_extractor()
+        mm_processor_kwargs = dict(
+            **mm_processor_kwargs,
+            sampling_rate=feature_extractor.sampling_rate,
+        )
+
+        # Already resampled by _get_processor_data
+        assert is_list_of(audios, np.ndarray)
+
+        # Ultravox processor doesn't support multiple inputs,
+        # therefore we need to input text and audio one by one
+        audio_features, audio_token_len = [], []
+        shared_outputs = {}
+        for audio in audios:
+            # NOTE: Ultravox processor accepts "audio" instead of "audios"
+            item_processor_data = dict(**processor_data, audio=audio)
+
+            item_outputs = super()._call_hf_processor(
+                hf_processor,
+                prompt=prompt,
+                processor_data=item_processor_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+
+            audio_features.append(item_outputs.pop("audio_values")[0])
+            audio_token_len.append(item_outputs.pop("audio_token_len").item())
+            shared_outputs = item_outputs
+
+        combined_outputs = dict(
+            **shared_outputs,
+            audio_features=audio_features,
+            audio_token_len=audio_token_len,
+        )
+        return BatchFeature(combined_outputs)
 
     def _get_prompt_replacements(
         self,
@@ -147,7 +146,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         mm_processor_kwargs: Mapping[str, object],
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
-        placeholder = hf_processor.audio_token_replacement
+        placeholder = hf_processor.audio_token_replacement  # type: ignore
 
         def get_replacement_ultravox(item_idx: int):
             audio_token_len = hf_inputs["audio_token_len"][item_idx]
@@ -171,7 +170,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
 
         audio_count = mm_counts["audio"]
         audio = np.zeros(audio_len)
-        data = {"audio": [(audio, sampling_rate)] * audio_count}
+        data = {"audio": [audio] * audio_count}
 
         return ProcessorInputs(
             prompt_text="<|audio|>" * audio_count,

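Because the Ultravox HF processor only accepts one audio clip per call, the merged processor above loops over items and stitches the per-item outputs back together. A self-contained toy version of that batching pattern (fake_processor and its fields are illustrative, not the real HF processor):

def fake_processor(text: str, audio: list) -> dict:
    # Stands in for one hf_processor(...) call on a single clip
    return {"input_ids": [1, 2, 3],
            "audio_values": [audio],
            "audio_token_len": max(1, len(audio) // 2)}

def process_all(text: str, audios: list) -> dict:
    audio_features, audio_token_len = [], []
    shared: dict = {}
    for audio in audios:
        out = fake_processor(text, audio)
        audio_features.append(out.pop("audio_values")[0])
        audio_token_len.append(out.pop("audio_token_len"))
        shared = out  # text-only fields are identical across items
    return {**shared,
            "audio_features": audio_features,
            "audio_token_len": audio_token_len}

print(process_all("<|audio|>", [[0.0] * 8, [0.0] * 4]))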
View File

@@ -1,3 +1,6 @@
+import numpy as np
+import numpy.typing as npt
+
 from vllm.inputs.registry import InputContext
 
 from .base import MultiModalPlugin
@@ -21,3 +24,18 @@ class AudioPlugin(MultiModalPlugin):
     def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
         raise NotImplementedError(
             "There is no default maximum multimodal tokens")
+
+
+def resample_audio(
+    audio: npt.NDArray[np.floating],
+    *,
+    orig_sr: float,
+    target_sr: float,
+) -> npt.NDArray[np.floating]:
+    try:
+        import librosa
+    except ImportError as exc:
+        msg = "Please install vllm[audio] for audio support."
+        raise ImportError(msg) from exc
+
+    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

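Usage sketch for the helper added above; librosa comes from the optional vllm[audio] extra, and the lengths below follow from the 8 kHz to 16 kHz ratio:

import numpy as np

from vllm.multimodal.audio import resample_audio

waveform = np.zeros(8000, dtype=np.float32)  # one second at 8 kHz
resampled = resample_audio(waveform, orig_sr=8000, target_sr=16000)
print(resampled.shape)  # (16000,), one second at the 16 kHz rate these audio models expect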
View File

@ -15,31 +15,32 @@ _T = TypeVar("_T")
# yapf: disable # yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
""" """
A :class:`transformers.image_utils.ImageInput` representing a single image, A :class:`transformers.image_utils.ImageInput` representing a single image
which can be passed to a HuggingFace :code:`ImageProcessor`. item, which can be passed to a HuggingFace :code:`ImageProcessor`.
""" """
VideoItem: TypeAlias = Union[ VideoItem: TypeAlias = Union[
List[Image], list[Image],
np.ndarray, np.ndarray,
torch.Tensor, torch.Tensor,
List[np.ndarray], list[np.ndarray],
List[torch.Tensor], list[torch.Tensor],
] ]
""" """
A :class:`transformers.image_utils.VideoInput` representing a single video
A :class:`transformers.image_utils.VideoInput` representing a single video, item, which can be passed to a HuggingFace :code:`VideoProcessor`.
which can be passed to a HuggingFace :code:`VideoProcessor`.
""" """
AudioItem: TypeAlias = Union[ AudioItem: TypeAlias = Union[
np.ndarray, np.ndarray,
List[float], list[float],
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead # `(audio, sampling_rate)`: If the audio's sampling rate is different
# from that expected by the model, we need to resample it.
tuple[np.ndarray, float],
] ]
""" """
Represents a single audio that can be inputted to a HuggingFace Represents a single audio
:code:`AudioProcessor`. item, which can be passed to a HuggingFace :code:`AudioProcessor`.
""" """
# yapf: enable # yapf: enable

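With the deprecation note gone, all three AudioItem forms are accepted on equal footing; the tuple form carries its own sampling rate so the processor can resample it. A sketch with illustrative values:

import numpy as np

audio_array = np.zeros(16000, dtype=np.float32)            # assumed to already match the model's rate
audio_list = [0.0] * 16000                                  # list[float], same assumption
audio_tuple = (np.zeros(48000, dtype=np.float32), 48000)    # (audio, sampling_rate), resampled by the processor

mm_data = {"audio": [audio_array, audio_list, audio_tuple]}  # e.g. the "audio" entry of multi_modal_data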
View File

@@ -17,6 +17,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
 
+from .audio import resample_audio
 from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                      MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                      VideoItem)
@@ -30,7 +31,7 @@ _PromptSeq = Union[str, list[int]]
 @dataclass
 class PromptReplacement:
     modality: str
-    """The modality for which the replacement is made"""
+    """The modality for which the replacement is made."""
 
     target: _PromptSeq
     """The text or token sequence to find and replace."""
@@ -211,31 +212,8 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
     corresponds to a list.
     """
 
-    @property
-    def image(self) -> list[ImageItem]:
-        return self["image"]
-
-    @property
-    def video(self) -> list[VideoItem]:
-        return self["video"]
-
-    @property
-    def audio(self) -> list[AudioItem]:
-        return self["audio"]
-
-    def get_image_size(self, item_idx: int) -> ImageSize:
-        image = self.image[item_idx]
-
-        if isinstance(image, Image):
-            return ImageSize(*image.size)
-        if isinstance(image, (np.ndarray, torch.Tensor)):
-            _, h, w = image.shape
-            return ImageSize(w, h)
-
-        assert_never(image)
-
-
-def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
-    """
-    Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
-    """
+    @staticmethod
+    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
+        """
+        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
+        """
@@ -245,15 +223,82 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
             # yapf: disable
             if k == "video":
                 # Special case since even a single item can be a list
-                multi_data[k] = v if is_list_of(v, list) else [v]  # type: ignore[index]
+                multi_data[k] = (  # type: ignore[index]
+                    v if is_list_of(v, (list, torch.Tensor)) else [v]
+                )
             elif k in ("image", "audio"):
-                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
+                multi_data[k] = (  # type: ignore[index]
+                    v if isinstance(v, (list, torch.Tensor)) else [v]
+                )
             else:
                 multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
             # yapf: enable
 
         return multi_data
 
+    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
+    # `self.images` doesn't update this dictionary, which may be confusing
+    # We annotate the getter methods as `Sequence` to prevent others from
+    # trying to update the list in this way
+    @property
+    def images(self) -> Sequence[ImageItem]:
+        return self.get("image", [])
+
+    @property
+    def videos(self) -> Sequence[VideoItem]:
+        return self.get("video", [])
+
+    @property
+    def audios(self) -> Sequence[AudioItem]:
+        return self.get("audio", [])
+
+    def get_image_size(self, item_idx: int) -> ImageSize:
+        image = self.images[item_idx]
+
+        if isinstance(image, Image):
+            return ImageSize(*image.size)
+        if isinstance(image, (np.ndarray, torch.Tensor)):
+            _, h, w = image.shape
+            return ImageSize(w, h)
+
+        assert_never(image)
+
+    def get_audio_with_sr(
+        self,
+        item_idx: int,
+        *,
+        default_sr: float,
+    ) -> tuple[np.ndarray, float]:
+        audio = self.audios[item_idx]
+
+        if isinstance(audio, tuple):
+            return audio
+        if isinstance(audio, list):
+            return np.array(audio), default_sr
+        if isinstance(audio, np.ndarray):
+            return audio, default_sr
+
+        assert_never(audio)
+
+    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
+        """
+        If :code:`drop_sr=True`, the audio items in this dictionary are updated
+        to be NumPy arrays which implicitly means that their sampling rate is
+        the same as the model's expected sampling rate; otherwise, they remain
+        as :code:`(audio, new_sr)` tuples.
+        """
+        if not self.audios:
+            return
+
+        new_audios = []
+        for item_idx in range(len(self.audios)):
+            audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
+            audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
+
+            new_audios.append(audio if drop_sr else (audio, new_sr))
+
+        self["audio"] = new_audios
+
 
 class _TokenMatch(NamedTuple):
     start_idx: int
@@ -596,18 +641,20 @@ class BaseMultiModalProcessor(ABC):
 
     def _get_processor_data(
         self,
-        mm_data: MultiModalDataDict,
-    ) -> BatchFeature:
+        mm_items: MultiModalDataItems,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         processor_data = dict[str, Any]()
         passthrough_data = dict[str, Any]()
-        for k, v in mm_data.items():
+
+        for k, v in mm_items.items():
             # TODO: Make a separate modality for embedding inputs
             # to avoid confusion
             if k in ("image", "video", "audio"):
                 if isinstance(v, torch.Tensor) and v.ndim == 3:
                     # Pass through embedding inputs (single)
                     passthrough_data[f"{k}_embeds"] = [v]
-                elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
+                elif (is_list_of(v, torch.Tensor) and len(v) > 0
+                      and v[0].ndim == 2):
                     # Pass through embedding inputs (multi)
                     passthrough_data[f"{k}_embeds"] = v
                 else:
@@ -615,40 +662,41 @@ class BaseMultiModalProcessor(ABC):
                     processor_data[f"{k}s"] = v
             else:
                 processor_data[k] = v
 
         return processor_data, passthrough_data
 
+    def _call_hf_processor(
+        self,
+        hf_processor: ProcessorMixin,
+        prompt: str,
+        processor_data: Mapping[str, object],
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        return self.ctx.call_hf_processor(
+            hf_processor,
+            prompt,
+            processor_data,
+            mm_processor_kwargs,
+        )
+
     def _apply_hf_processor(
         self,
         prompt: str,
-        mm_data: MultiModalDataDict,
+        mm_items: MultiModalDataItems,
         mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # some mm_processor_kwargs may be used in processor initialization
         # instead of processor call
         hf_processor = self._get_hf_processor(**mm_processor_kwargs)
-        processor_data, passthrough_data = self._get_processor_data(mm_data)
 
-        assert callable(hf_processor)
-        mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
+        processor_data, passthrough_data = self._get_processor_data(mm_items)
+
+        hf_inputs = self._call_hf_processor(
             hf_processor,
-            mm_processor_kwargs,
+            prompt=prompt,
+            processor_data=processor_data,
+            mm_processor_kwargs=mm_processor_kwargs,
         )
-
-        try:
-            hf_inputs = hf_processor(
-                text=prompt,  # type: ignore
-                **processor_data,
-                **mm_processor_kwargs,
-                return_tensors="pt",
-            )
-        except Exception as exc:
-            data = dict(text=prompt, **processor_data)
-            raise RuntimeError(
-                f"Failed to apply {type(hf_processor).__name__} "
-                f"on data={data} with kwargs={mm_processor_kwargs}") from exc
-
         hf_inputs.update(passthrough_data)
 
         return hf_inputs
@@ -730,14 +778,13 @@ class BaseMultiModalProcessor(ABC):
         3. Extract information about the placeholder tokens from the
            processed token IDs.
         """
-        tokenizer = self._get_tokenizer()
+        mm_items = MultiModalDataItems.from_dict(mm_data)
 
-        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
+        hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
                                              mm_processor_kwargs)
         prompt_ids, = hf_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(hf_inputs)
 
-        mm_items = to_multi_format(mm_data)
         prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
                                                      mm_processor_kwargs)
         all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
@@ -749,6 +796,7 @@ class BaseMultiModalProcessor(ABC):
                 prompt_ids, mm_item_counts)
 
         if all_placeholders:
+            tokenizer = self._get_tokenizer()
             prompt_text = _decode(tokenizer, prompt_ids)
         else:
             (

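A short sketch of the new normalization and resampling flow introduced above (single item in, list out); running it needs the vllm[audio] extra for librosa, and the rates are illustrative:

import numpy as np

from vllm.multimodal.processing import MultiModalDataItems

items = MultiModalDataItems.from_dict(
    {"audio": (np.zeros(48000, dtype=np.float32), 48000)})  # a lone (audio, sr) tuple gets wrapped in a list
print(len(items.audios))       # 1
items.resample_audios(16000)   # the tuple becomes a plain 16 kHz array (drop_sr=True by default)
print(items.audios[0].shape)   # (16000,)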
View File

@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers # `collections` helpers
def is_list_of( def is_list_of(
value: object, value: object,
typ: Type[T], typ: Union[type[T], tuple[type[T], ...]],
*, *,
check: Literal["first", "all"] = "first", check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]: ) -> TypeIs[List[T]]:
@ -1282,6 +1282,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def supports_kw( def supports_kw(
callable: Callable[..., object], callable: Callable[..., object],
kw_name: str, kw_name: str,
*,
requires_kw_only: bool = False, requires_kw_only: bool = False,
allow_var_kwargs: bool = True, allow_var_kwargs: bool = True,
) -> bool: ) -> bool:
@ -1326,6 +1327,8 @@ def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]], init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]], inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object], callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e., """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
@ -1344,11 +1347,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs = get_allowed_kwarg_only_overrides( runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, callable,
overrides=inference_kwargs, overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs) requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided # Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides( init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs) callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference # Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values. # time values over the initialization time values.
@ -1359,6 +1368,8 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Mapping[str, object]], overrides: Optional[Mapping[str, object]],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False, allow_var_kwargs: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@ -1390,16 +1401,21 @@ def get_allowed_kwarg_only_overrides(
for kwarg_name, val in overrides.items() for kwarg_name, val in overrides.items()
if supports_kw(callable, if supports_kw(callable,
kwarg_name, kwarg_name,
requires_kw_only=True, requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs) allow_var_kwargs=allow_var_kwargs)
} }
# If anything is dropped, log a warning # If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys() dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys: if dropped_keys:
if requires_kw_only:
logger.warning( logger.warning(
"The following intended overrides are not keyword-only args " "The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys) "and and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and and will be dropped: %s", dropped_keys)
return filtered_overrides return filtered_overrides
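
A toy illustration of the new requires_kw_only switch; the example function and override values are made up, but the filtering behaviour matches the code above:

from vllm.utils import get_allowed_kwarg_only_overrides

def fake_hf_call(audios=None, *, sampling_rate: int = 16000):
    return audios, sampling_rate

overrides = {"audios": [], "sampling_rate": 22050, "bogus": 1}

# Default behaviour keeps only keyword-only params -> {'sampling_rate': 22050}
print(get_allowed_kwarg_only_overrides(fake_hf_call, overrides=overrides))

# HF processor __call__ signatures take most arguments as positional-or-keyword,
# so call_hf_processor filters with requires_kw_only=False and keeps them:
# -> {'audios': [], 'sampling_rate': 22050}
print(get_allowed_kwarg_only_overrides(fake_hf_call, overrides=overrides,
                                       requires_kw_only=False,
                                       allow_var_kwargs=True))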