[VLM] Merged multimodal processor for Qwen2-Audio (#11303)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c6b0a7d3ba
commit 6142ef0ada
@@ -18,6 +18,10 @@ question_per_audio_count = {
2: "What sport and what nursery rhyme are referenced?"
}

# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


# Ultravox 0.3
def run_ultravox(question: str, audio_count: int):

@@ -33,6 +37,8 @@ def run_ultravox(question: str, audio_count: int):
add_generation_prompt=True)

llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
@@ -5,6 +5,7 @@ import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

@@ -130,16 +131,14 @@ def run_test(
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
import librosa

hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
audios=[(resample_audio(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
@@ -1,11 +1,11 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type)
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
Optional, Protocol, Union)

from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never

from vllm.logger import init_logger

@@ -26,6 +26,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)


@dataclass(frozen=True)
@@ -38,24 +39,28 @@ class InputContext:
model_config: "ModelConfig"
"""The configuration of the model."""

def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
def get_hf_config(
self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
/,
) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.

Raises:
TypeError: If the model is not of the specified type.
TypeError: If the configuration is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")

return hf_config

def get_hf_image_processor_config(self) -> Dict[str, Any]:
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
@@ -74,18 +79,37 @@ class InputContext:

return mm_config

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.

Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")

return hf_processor


@dataclass(frozen=True)
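The new positional-only `typ` parameter lets callers narrow the returned config or processor type (or a tuple of candidate types) in one call instead of a separate `isinstance` assert. A minimal standalone sketch of the same narrowing pattern, using made-up class names rather than the vLLM API:

```python
# Standalone sketch of the isinstance-based narrowing used by get_hf_config()
# and get_hf_processor() above; MyConfig/OtherConfig are illustrative stand-ins.
from typing import Union


def narrow(obj: object, typ: Union[type, tuple[type, ...]]) -> object:
    # Mirrors the check added above: accept one type or a tuple of types.
    if not isinstance(obj, typ):
        raise TypeError(f"Expected type: {typ}, but found type: {type(obj)}")
    return obj


class MyConfig:
    pass


class OtherConfig:
    pass


cfg = narrow(MyConfig(), MyConfig)                 # ok
cfg = narrow(MyConfig(), (MyConfig, OtherConfig))  # ok: tuple of candidates
# narrow(MyConfig(), OtherConfig)                  # would raise TypeError
```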
@@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
return super().get_hf_processor(
typ,
tokenizer=self.tokenizer,
**kwargs,
)

def resolve_hf_processor_call_kwargs(
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
) -> BatchFeature:
assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

return resolve_mm_processor_kwargs(
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)

try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

N = TypeVar("N", bound=Type[nn.Module])
raise RuntimeError(msg) from exc


N = TypeVar("N", bound=type[nn.Module])


class DummyData(NamedTuple):
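`call_hf_processor` above merges the init-time `mm_processor_kwargs` with per-call kwargs, filters them down to what the processor's `__call__` actually accepts (via `resolve_mm_processor_kwargs` with `requires_kw_only=False, allow_var_kwargs=True`), and wraps any failure in a `RuntimeError` that records the offending inputs. A hedged sketch of just the filtering step, using a local helper rather than the vLLM utilities:

```python
# Illustrative only: a simplified stand-in for the kwarg filtering that
# resolve_mm_processor_kwargs() performs before the processor is called.
import inspect
from typing import Any, Callable, Mapping


def keep_supported_kwargs(fn: Callable[..., Any],
                          kwargs: Mapping[str, object]) -> dict[str, Any]:
    params = inspect.signature(fn).parameters
    has_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD
                     for p in params.values())
    # Keep a kwarg only if the callable declares it, or if it takes **kwargs.
    return {k: v for k, v in kwargs.items() if has_var_kw or k in params}


def fake_processor(text: str, *, sampling_rate: int = 16000) -> dict:
    return {"text": text, "sampling_rate": sampling_rate}


merged = {"sampling_rate": 16000, "num_crops": 16}     # init-time + per-call
print(keep_supported_kwargs(fake_processor, merged))   # num_crops is dropped
```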
@@ -232,7 +272,7 @@ class InputRegistry:

return wrapper

def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

@@ -257,7 +297,7 @@ class InputRegistry:

return wrapper

def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)

@@ -368,14 +408,14 @@ class InputRegistry:

return wrapper

def _get_model_input_processor(self, model_cls: Type[nn.Module]):
def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)

def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: Dict[str, Any],
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it
@@ -133,8 +133,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore

def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))

if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
@@ -34,7 +34,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

@@ -330,20 +329,27 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()

def _apply_hf_processor(
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
mm_data: MultiModalDataDict,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
# which will cause OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids

return processed_outputs

def _get_prompt_replacements(
@@ -19,45 +19,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property, lru_cache
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import librosa
import numpy as np
import torch
import torch.nn as nn
from transformers import Qwen2AudioEncoder
from transformers import BatchFeature, ProcessorMixin
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioEncoder,
Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)


# # === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape:
`(num_audios, num_mel_bins, 3000)`
"""
"""Shape: `(num_audios, num_mel_bins, 3000)`"""

feature_attention_mask: torch.Tensor
"""Shape: `(num_audios, 3000)`
"""
"""Shape: `(num_audios, 3000)`"""


# === Audio Encoder === #
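For context on the `(num_audios, num_mel_bins, 3000)` and `(num_audios, 3000)` shapes documented in `Qwen2AudioInputs` above: Whisper-style feature extraction pads or truncates every clip to 30 s, i.e. 3000 mel frames. A quick check with the default `WhisperFeatureExtractor` (illustrative; in the model the extractor is loaded from its processor):

```python
# Assumes transformers and torch are installed; the default extractor uses
# 80 mel bins and a 16 kHz sampling rate, so one clip yields (1, 80, 3000).
import numpy as np
from transformers.models.whisper import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()
audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz

out = feature_extractor([audio],
                        sampling_rate=16000,
                        return_attention_mask=True,
                        padding="max_length",
                        return_tensors="pt")
print(out["input_features"].shape)  # torch.Size([1, 80, 3000])
print(out["attention_mask"].shape)  # torch.Size([1, 3000])
```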
@@ -74,187 +72,114 @@ class Qwen2AudioMultiModalProjector(nn.Module):
return hidden_states


def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_audios = mm_counts["audio"]
max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
max_llm_audio_tokens = max_tokens_per_audio * num_audios
if seq_len - max_llm_audio_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
"please increase max_model_len or reduce audio limit by "
"--limit-mm-per-prompt.")

audio_token_index = ctx.model_config.hf_config.audio_token_index

dummy_seqdata = SequenceData.from_prompt_token_counts(
(audio_token_index, max_llm_audio_tokens),
(0, seq_len - max_llm_audio_tokens),
)
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return DummyData(
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
"audio":
consecutive_placeholder_ranges(num_items=num_audios,
item_size=max_tokens_per_audio)
})


def get_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
**kwargs,
):
"""Gets a processor for the given model name via HuggingFace.

Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
"""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor

try:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e

return processor


cached_get_processor = lru_cache(get_processor)


# From Qwen2AudioEncoder._get_feat_extract_output_lengths
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
feat_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (feat_lengths - 2) // 2 + 1
return feat_lengths, output_lengths


def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
max_source_position = (
ctx.model_config.hf_config.audio_config.max_source_positions)
hf_config = ctx.get_hf_config(Qwen2AudioConfig)
max_source_position = hf_config.audio_config.max_source_positions
output_lengths = (max_source_position - 2) // 2 + 1
return output_lengths


def input_processor_for_qwen2_audio(
ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):

audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
def _get_hf_processor(self) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)

if len(audios) == 0:
return inputs
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore

processor = cached_get_processor(ctx.model_config.model)
resampled_audios = [
librosa.resample(audio,
orig_sr=sampling_rate,
target_sr=processor.feature_extractor.sampling_rate)
for audio, sampling_rate in audios
]
audio_input_lengths = np.array(
[min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
def _get_processor_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
audio_input_lengths)
return super()._get_processor_data(mm_items)

audio_token_index = ctx.model_config.hf_config.audio_token_index
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])

input_ids = inputs['prompt_token_ids']
if audios:
processor_data["audios"] = audios

new_input_ids = []
audio_num = input_ids.count(audio_token_index)
assert len(audio_input_lengths) == audio_num, \
(f'The text input contains {audio_num} audio tokens, '
f'but {len(audio_input_lengths)} audios provided')
start = 0
for audio_idx in range(audio_num):
end = input_ids.index(audio_token_index, start)
new_input_ids.extend(input_ids[start:end]) # text part
feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass

new_input_ids.extend([audio_token_index] *
audio_output_lengths[audio_idx])
start = end + 1
new_input_ids.extend(input_ids[start:])
return super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

return token_inputs(
prompt_token_ids=new_input_ids,
prompt=inputs.get("prompt"),
multi_modal_data=multi_modal_data,
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
placeholder = hf_config.audio_token_index

feature_attention_mask = hf_inputs.get("feature_attention_mask")
if feature_attention_mask is None:
audio_output_lengths = []
else:
_, audio_output_lengths = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))

def input_mapper_for_qwen2_audio(
ctx: InputContext,
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
"""Input mapper for Qwen2-Audio."""
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]
def get_replacement_qwen2_audio(item_idx: int):
return [placeholder] * audio_output_lengths[item_idx]

if len(multi_modal_data) == 0:
return MultiModalKwargs()

processor = cached_get_processor(ctx.model_config.model)
audio_feature_extractor = processor.feature_extractor
if audio_feature_extractor is None:
raise RuntimeError(
"No HuggingFace audio_feature_extractor is available "
"to process the audio object")

try:
resampled_audios = [
librosa.resample(
audio,
orig_sr=sampling_rate,
target_sr=processor.feature_extractor.sampling_rate)
for audio, sampling_rate in multi_modal_data
return [
PromptReplacement(
modality="audio",
target=[placeholder],
replacement=get_replacement_qwen2_audio,
)
]
batch_data = audio_feature_extractor(resampled_audios,
sampling_rate=16000,
return_attention_mask=True,
padding="max_length",
return_tensors="pt").data
batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
except Exception:
logger.error("Failed to process audio (%s)", multi_modal_data)
raise

return MultiModalKwargs(batch_data)
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)

audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}

return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)


@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
input_mapper_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_qwen2_audio_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
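The replacement logic above sizes the `<|AUDIO|>` expansion from the feature lengths: `_get_feat_extract_output_lengths` mirrors the two stride-2 convolutions in `Qwen2AudioEncoder`, and `get_max_qwen2_audio_audio_tokens` applies the second step to `max_source_positions`. A worked example of the arithmetic:

```python
# Worked example of the length arithmetic above (values are assumptions based
# on Whisper-style 30 s / 3000-frame inputs, not read from a real config).
mel_frames = 3000                             # padded mel frames per clip
feat_lengths = (mel_frames - 1) // 2 + 1      # after the first conv -> 1500
output_lengths = (feat_lengths - 2) // 2 + 1  # after the second conv -> 750

max_source_positions = 1500                   # assumed encoder setting
max_audio_tokens = (max_source_positions - 2) // 2 + 1

assert feat_lengths == 1500
assert output_lengths == max_audio_tokens == 750  # placeholders per audio
```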
@@ -289,9 +214,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

return get_sampler()

def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
@@ -3,7 +3,7 @@

import math
from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)

import numpy as np

@@ -11,7 +11,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature
from transformers import BatchFeature, ProcessorMixin
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

@@ -25,11 +25,11 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,

@@ -61,8 +61,8 @@ def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
return cached_feature_extractor(
ctx.get_hf_config(UltravoxConfig).audio_model_id)
hf_config = ctx.get_hf_config(UltravoxConfig)
return cached_feature_extractor(hf_config.audio_model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
@@ -73,72 +73,71 @@
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().audio_processor.feature_extractor

def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
if sr != feature_extractor.sampling_rate:
try:
import librosa
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
audio = librosa.resample(audio,
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}

def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data or not mm_data.get("audio", None):
return super()._apply_hf_processor(prompt, mm_data,
mm_processor_kwargs)

audio_data = mm_data["audio"]
if not isinstance(audio_data, list):
audio_data = [audio_data]

# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
tokenizer = self._get_tokenizer()
audio_features, audio_token_len = [], []
processed_inputs = {}
for audio, sr in audio_data:
data = self._resample_audio(audio, sr)
processed_inputs = super()._apply_hf_processor(
prompt, data, mm_processor_kwargs)
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
skip_special_tokens=False)
audio_features.append(
processed_inputs.pop("audio_values").squeeze(0))
audio_token_len.append(
processed_inputs.pop("audio_token_len").item())

return dict(
**processed_inputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore

def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Ultravox uses "audio" instead of "audios" as calling keyword
processor_data, passthrough_data = super()._get_processor_data(mm_data)
if "audios" in processor_data:
processor_data["audio"] = processor_data.pop("audios")
return processor_data, passthrough_data
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)

return super()._get_processor_data(mm_items)

def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])

if not audios:
return super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)

# Already resampled by _get_processor_data
assert is_list_of(audios, np.ndarray)

# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
audio_features, audio_token_len = [], []
shared_outputs = {}
for audio in audios:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
item_processor_data = dict(**processor_data, audio=audio)

item_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=item_processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

audio_features.append(item_outputs.pop("audio_values")[0])
audio_token_len.append(item_outputs.pop("audio_token_len").item())
shared_outputs = item_outputs

combined_outputs = dict(
**shared_outputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
return BatchFeature(combined_outputs)

def _get_prompt_replacements(
self,

@@ -147,7 +146,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
placeholder = hf_processor.audio_token_replacement
placeholder = hf_processor.audio_token_replacement # type: ignore

def get_replacement_ultravox(item_idx: int):
audio_token_len = hf_inputs["audio_token_len"][item_idx]

@@ -171,7 +170,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [(audio, sampling_rate)] * audio_count}
data = {"audio": [audio] * audio_count}

return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
@@ -1,3 +1,6 @@
import numpy as np
import numpy.typing as npt

from vllm.inputs.registry import InputContext

from .base import MultiModalPlugin

@@ -21,3 +24,18 @@ class AudioPlugin(MultiModalPlugin):
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")


def resample_audio(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
) -> npt.NDArray[np.floating]:
try:
import librosa
except ImportError as exc:
msg = "Please install vllm[audio] for audio support."
raise ImportError(msg) from exc

return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
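A small usage sketch for the new `resample_audio` helper (requires the optional `vllm[audio]` extra, i.e. librosa):

```python
# One second at 44.1 kHz resampled to the 16 kHz expected by Whisper-style
# feature extractors; the output length follows librosa's resampling.
import numpy as np

from vllm.multimodal.audio import resample_audio

audio = np.zeros(44100, dtype=np.float32)
resampled = resample_audio(audio, orig_sr=44100, target_sr=16000)
print(resampled.shape)  # (16000,)
```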
@@ -15,31 +15,32 @@ _T = TypeVar("_T")
# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image,
which can be passed to a HuggingFace :code:`ImageProcessor`.
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""

VideoItem: TypeAlias = Union[
List[Image],
list[Image],
np.ndarray,
torch.Tensor,
List[np.ndarray],
List[torch.Tensor],
list[np.ndarray],
list[torch.Tensor],
]
"""

A :class:`transformers.image_utils.VideoInput` representing a single video,
which can be passed to a HuggingFace :code:`VideoProcessor`.
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

AudioItem: TypeAlias = Union[
np.ndarray,
List[float],
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead
list[float],
# `(audio, sampling_rate)`: If the audio's sampling rate is different
# from that expected by the model, we need to resample it.
tuple[np.ndarray, float],
]
"""
Represents a single audio that can be inputted to a HuggingFace
:code:`AudioProcessor`.
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
"""
# yapf: enable
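Concretely, the three `AudioItem` forms accepted as `multi_modal_data["audio"]` entries (values are illustrative only):

```python
import numpy as np

audio_16k = np.zeros(16000, dtype=np.float32)

audio_items = [
    audio_16k,            # np.ndarray, assumed to already match the model's rate
    audio_16k.tolist(),   # list[float]
    (audio_16k, 44100),   # (audio, sampling_rate) tuple; resampled if needed
]
```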
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of

from .audio import resample_audio
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)

@@ -30,7 +31,7 @@ _PromptSeq = Union[str, list[int]]
@dataclass
class PromptReplacement:
modality: str
"""The modality for which the replacement is made"""
"""The modality for which the replacement is made."""

target: _PromptSeq
"""The text or token sequence to find and replace."""
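As used later in this commit by the Qwen2-Audio and Ultravox processors, a `PromptReplacement` pairs a target placeholder with a per-item expansion callback. A minimal sketch (the token id and lengths below are made up for illustration):

```python
from vllm.multimodal.processing import PromptReplacement

AUDIO_PLACEHOLDER_ID = 12345        # hypothetical placeholder token id
audio_output_lengths = [750, 300]   # per-item expansion, e.g. from HF outputs


def get_replacement(item_idx: int) -> list[int]:
    # The item_idx-th placeholder is expanded to this many audio tokens.
    return [AUDIO_PLACEHOLDER_ID] * audio_output_lengths[item_idx]


prompt_repls = [
    PromptReplacement(
        modality="audio",
        target=[AUDIO_PLACEHOLDER_ID],  # find the single placeholder token
        replacement=get_replacement,    # replace it per audio item
    )
]
```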
@@ -211,20 +212,48 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
corresponds to a list.
"""

@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()

for k, v in data.items():
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = (  # type: ignore[index]
v if is_list_of(v, (list, torch.Tensor)) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = (  # type: ignore[index]
v if isinstance(v, (list, torch.Tensor)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
# yapf: enable

return multi_data

# NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
# `self.images` doesn't update this dictionary, which may be confusing
# We annotate the getter methods as `Sequence` to prevent others from
# trying to update the list in this way
@property
def image(self) -> list[ImageItem]:
return self["image"]
def images(self) -> Sequence[ImageItem]:
return self.get("image", [])

@property
def video(self) -> list[VideoItem]:
return self["video"]
def videos(self) -> Sequence[VideoItem]:
return self.get("video", [])

@property
def audio(self) -> list[AudioItem]:
return self["audio"]
def audios(self) -> Sequence[AudioItem]:
return self.get("audio", [])

def get_image_size(self, item_idx: int) -> ImageSize:
image = self.image[item_idx]
image = self.images[item_idx]

if isinstance(image, Image):
return ImageSize(*image.size)
@@ -234,25 +263,41 @@ class MultiModalDataItems(UserDict[str, list[Any]]):

assert_never(image)

def get_audio_with_sr(
self,
item_idx: int,
*,
default_sr: float,
) -> tuple[np.ndarray, float]:
audio = self.audios[item_idx]

def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), default_sr
if isinstance(audio, np.ndarray):
return audio, default_sr

for k, v in data.items():
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = v if is_list_of(v, list) else [v]  # type: ignore[index]
elif k in ("image", "audio"):
multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
else:
multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
# yapf: enable
assert_never(audio)

return multi_data
def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
"""
If :code:`drop_sr=True`, the audio items in this dictionary are updated
to be NumPy arrays which implicitly means that their sampling rate is
the same as the model's expected sampling rate; otherwise, they remain
as :code:`(audio, new_sr)` tuples.
"""
if not self.audios:
return

new_audios = []
for item_idx in range(len(self.audios)):
audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)

new_audios.append(audio if drop_sr else (audio, new_sr))

self["audio"] = new_audios


class _TokenMatch(NamedTuple):
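Putting the new `MultiModalDataItems` helpers together (`from_dict`, the `audios` view, and `resample_audios`), roughly as the audio processors in this commit use them; requires the `vllm[audio]` extra for the resampling step:

```python
import numpy as np

from vllm.multimodal.processing import MultiModalDataItems

mm_items = MultiModalDataItems.from_dict({
    "audio": [(np.zeros(44100, dtype=np.float32), 44100)],  # (audio, orig_sr)
})

mm_items.resample_audios(16000)  # drop_sr=True: items become plain 16 kHz arrays

audio = mm_items.audios[0]
print(type(audio), audio.shape)  # <class 'numpy.ndarray'> (16000,)
```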
@@ -596,18 +641,20 @@ class BaseMultiModalProcessor(ABC):

def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> BatchFeature:
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():

for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
@@ -615,40 +662,41 @@ class BaseMultiModalProcessor(ABC):
processor_data[f"{k}s"] = v
else:
processor_data[k] = v

return processor_data, passthrough_data

def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
return self.ctx.call_hf_processor(
hf_processor,
prompt,
processor_data,
mm_processor_kwargs,
)

def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)

processor_data, passthrough_data = self._get_processor_data(mm_data)
processor_data, passthrough_data = self._get_processor_data(mm_items)

assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
hf_inputs = self._call_hf_processor(
hf_processor,
mm_processor_kwargs,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)

try:
hf_inputs = hf_processor(
text=prompt, # type: ignore
**processor_data,
**mm_processor_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)

raise RuntimeError(
f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={mm_processor_kwargs}") from exc

hf_inputs.update(passthrough_data)

return hf_inputs
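The base `_call_hf_processor` simply delegates to `ctx.call_hf_processor`; model-specific processors override it to adjust inputs or outputs around the base call, as Phi-3-Vision, Qwen2-Audio and Ultravox do elsewhere in this commit. A schematic subclass (a real subclass also implements hooks such as `_get_hf_processor`, `_get_prompt_replacements` and `_get_dummy_mm_inputs`, as those models do):

```python
# Schematic only; the override signature follows the base class shown above.
from collections.abc import Mapping

from transformers import BatchFeature, ProcessorMixin

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):

    def _call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        prompt: str,
        processor_data: Mapping[str, object],
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            hf_processor,
            prompt=prompt,
            processor_data=processor_data,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        # Model-specific post-processing of `outputs` would go here.
        return outputs
```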
@@ -730,14 +778,13 @@ class BaseMultiModalProcessor(ABC):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
tokenizer = self._get_tokenizer()
mm_items = MultiModalDataItems.from_dict(mm_data)

hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
mm_processor_kwargs)
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)

mm_items = to_multi_format(mm_data)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

@@ -749,6 +796,7 @@ class BaseMultiModalProcessor(ABC):
prompt_ids, mm_item_counts)

if all_placeholders:
tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids)
else:
(
@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers
def is_list_of(
value: object,
typ: Type[T],
typ: Union[type[T], tuple[type[T], ...]],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:

@@ -1282,6 +1282,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def supports_kw(
callable: Callable[..., object],
kw_name: str,
*,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:

@@ -1326,6 +1327,8 @@ def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,

@@ -1344,11 +1347,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs)
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)

# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)

# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.

@@ -1359,6 +1368,8 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""

@@ -1390,16 +1401,21 @@
for kwarg_name, val in overrides.items()
if supports_kw(callable,
kwarg_name,
requires_kw_only=True,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs)
}

# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and and will be dropped: %s", dropped_keys)

return filtered_overrides
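The widened `is_list_of` signature mirrors `isinstance`: `typ` may now be a tuple of candidate types, which is what the video normalization in `MultiModalDataItems.from_dict` relies on. Quick illustration:

```python
import torch

from vllm.utils import is_list_of

frames = [torch.zeros(2, 2), torch.zeros(2, 2)]

print(is_list_of(frames, (list, torch.Tensor)))  # True: tuple of candidates
print(is_list_of(frames, list))                  # False
```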