[VLM] Merged multimodal processor for Qwen2-Audio (#11303)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-19 14:14:17 +08:00 committed by GitHub
parent c6b0a7d3ba
commit 6142ef0ada
11 changed files with 414 additions and 358 deletions

View File

@ -18,6 +18,10 @@ question_per_audio_count = {
2: "What sport and what nursery rhyme are referenced?"
}
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# Ultravox 0.3
def run_ultravox(question: str, audio_count: int):
@ -33,6 +37,8 @@ def run_ultravox(question: str, audio_count: int):
add_generation_prompt=True)
llm = LLM(model=model_name,
max_model_len=4096,
max_num_seqs=5,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
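The new `max_num_seqs=5` line and the NOTE above exist to keep memory usage in check on smaller GPUs. For reference, a minimal sketch (not part of this file) of how the same conservative settings would be applied to the Qwen2-Audio model this PR targets, assuming the `Qwen/Qwen2-Audio-7B-Instruct` checkpoint:

from vllm import LLM

def run_qwen2_audio_sketch(audio_count: int) -> LLM:
    # Hypothetical helper mirroring the Ultravox settings above.
    return LLM(
        model="Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,  # keep the context short to avoid OOM on lower-end GPUs
        max_num_seqs=5,      # limit concurrent sequences for the same reason
        limit_mm_per_prompt={"audio": audio_count},
    )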

View File

@ -5,6 +5,7 @@ import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@ -130,16 +131,14 @@ def run_test(
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
import librosa
hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
audios=[(resample_audio(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]

View File

@ -1,11 +1,11 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type)
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
Optional, Protocol, Union)
from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
@ -26,6 +26,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
@dataclass(frozen=True)
@ -38,24 +39,28 @@ class InputContext:
model_config: "ModelConfig"
"""The configuration of the model."""
def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
def get_hf_config(
self,
typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
/,
) -> C:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the model is not of the specified type.
TypeError: If the configuration is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
if not isinstance(hf_config, typ):
raise TypeError("Invalid type of HuggingFace config. "
f"Expected type: {hf_config_type}, but "
f"Expected type: {typ}, but "
f"found type: {type(hf_config)}")
return hf_config
def get_hf_image_processor_config(self) -> Dict[str, Any]:
def get_hf_image_processor_config(self) -> dict[str, Any]:
"""
Get the HuggingFace image processor configuration of the model.
"""
@ -74,18 +79,37 @@ class InputContext:
return mm_config
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
if not isinstance(hf_processor, typ):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {typ}, but "
f"found type: {type(hf_processor)}")
return hf_processor
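The type check above is a plain `isinstance` call, so `typ` may be a single class or a tuple of classes. A standalone sketch of the pattern (`check_type` is a hypothetical helper; the real logic lives in `InputContext.get_hf_config` and `get_hf_processor`):

from typing import TypeVar, Union

T = TypeVar("T")

def check_type(obj: object, typ: Union[type[T], tuple[type[T], ...]]) -> T:
    # isinstance accepts either a single type or a tuple of types, which is
    # why both forms are allowed for `typ`.
    if not isinstance(obj, typ):
        raise TypeError(f"Expected type: {typ}, but found type: {type(obj)}")
    return obj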
@dataclass(frozen=True)
@ -93,39 +117,55 @@ class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
def get_hf_processor(
self,
typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
/,
**kwargs: object,
) -> P:
return super().get_hf_processor(
typ,
tokenizer=self.tokenizer,
**kwargs,
)
def resolve_hf_processor_call_kwargs(
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
) -> BatchFeature:
assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
return resolve_mm_processor_kwargs(
merged_kwargs = resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
requires_kw_only=False,
allow_var_kwargs=True,
)
try:
return hf_processor(
text=prompt,
**processor_data,
**merged_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")
raise RuntimeError(msg) from exc
N = TypeVar("N", bound=Type[nn.Module])
N = TypeVar("N", bound=type[nn.Module])
class DummyData(NamedTuple):
@ -232,7 +272,7 @@ class InputRegistry:
return wrapper
def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
@ -257,7 +297,7 @@ class InputRegistry:
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
return self._dummy_encoder_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
@ -368,14 +408,14 @@ class InputRegistry:
return wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]):
def _get_model_input_processor(self, model_cls: type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
def _ensure_mm_kwargs(
self,
inputs: SingletonInputs,
mm_processor_kwargs: Dict[str, Any],
mm_processor_kwargs: dict[str, Any],
):
if inputs["type"] == "token":
# In case the input processor for that model fails to set it

View File

@ -133,8 +133,8 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
hf_processor = self.ctx.get_hf_processor(
(LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)

View File

@ -34,7 +34,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
@ -330,20 +329,27 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()
def _apply_hf_processor(
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
mm_data: MultiModalDataDict,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
processed_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)
# The Phi3v processor inserts -1, -2, etc. as placeholders in prompt_ids,
# which will cause an OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here.
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids
return processed_outputs
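A quick illustration of the early replacement (the image token id below is an assumed value for illustration only; the real constant `_IMAGE_TOKEN_ID` is defined elsewhere in this file):

import torch

_IMAGE_TOKEN_ID = 32044  # assumed value, for illustration only
token_ids = torch.tensor([1, -1, -1, 29871, -2, 2])
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
# token_ids is now tensor([1, 32044, 32044, 29871, 32044, 2]) and can be
# decoded without raising OverflowError.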
def _get_prompt_replacements(

View File

@ -19,45 +19,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property, lru_cache
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import librosa
import numpy as np
import torch
import torch.nn as nn
from transformers import Qwen2AudioEncoder
from transformers import BatchFeature, ProcessorMixin
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioEncoder,
Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
# === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape:
`(num_audios, num_mel_bins, 3000)`
"""
"""Shape: `(num_audios, num_mel_bins, 3000)`"""
feature_attention_mask: torch.Tensor
"""Shape: `(num_audios, 3000)`
"""
"""Shape: `(num_audios, 3000)`"""
# === Audio Encoder === #
@ -74,187 +72,114 @@ class Qwen2AudioMultiModalProjector(nn.Module):
return hidden_states
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_audios = mm_counts["audio"]
max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx)
max_llm_audio_tokens = max_tokens_per_audio * num_audios
if seq_len - max_llm_audio_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
"please increase max_model_len or reduce audio limit by "
"--limit-mm-per-prompt.")
audio_token_index = ctx.model_config.hf_config.audio_token_index
dummy_seqdata = SequenceData.from_prompt_token_counts(
(audio_token_index, max_llm_audio_tokens),
(0, seq_len - max_llm_audio_tokens),
)
dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
return DummyData(
dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
"audio":
consecutive_placeholder_ranges(num_items=num_audios,
item_size=max_tokens_per_audio)
})
def get_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
**kwargs,
):
"""Gets a processor for the given model name via HuggingFace.
Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
"""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
try:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return processor
cached_get_processor = lru_cache(get_processor)
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
and the output length of the audio encoder
"""
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
return input_lengths, output_lengths
feat_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (feat_lengths - 2) // 2 + 1
return feat_lengths, output_lengths
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
max_source_position = (
ctx.model_config.hf_config.audio_config.max_source_positions)
hf_config = ctx.get_hf_config(Qwen2AudioConfig)
max_source_position = hf_config.audio_config.max_source_positions
output_lengths = (max_source_position - 2) // 2 + 1
return output_lengths
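A worked example of the two length formulas, assuming the usual 3000-frame Whisper mel padding and `audio_config.max_source_positions = 1500` (the assert only holds under that assumption):

mel_frames = 3000                             # padded mel frames per audio clip
feat_lengths = (mel_frames - 1) // 2 + 1      # 1500 positions after the first conv
output_lengths = (feat_lengths - 2) // 2 + 1  # 750 audio tokens seen by the LLM

max_source_positions = 1500                   # assumed value from audio_config
assert (max_source_positions - 2) // 2 + 1 == output_lengths == 750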
def input_processor_for_qwen2_audio(
ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
def _get_hf_processor(self) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
if len(audios) == 0:
return inputs
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore
processor = cached_get_processor(ctx.model_config.model)
resampled_audios = [
librosa.resample(audio,
orig_sr=sampling_rate,
target_sr=processor.feature_extractor.sampling_rate)
for audio, sampling_rate in audios
]
audio_input_lengths = np.array(
[min(3000, _.shape[0] // 160 + 1) for _ in resampled_audios])
def _get_processor_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
audio_input_lengths)
return super()._get_processor_data(mm_items)
audio_token_index = ctx.model_config.hf_config.audio_token_index
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])
input_ids = inputs['prompt_token_ids']
if audios:
processor_data["audios"] = audios
new_input_ids = []
audio_num = input_ids.count(audio_token_index)
assert len(audio_input_lengths) == audio_num, \
(f'The text input contains {audio_num} audio tokens, '
f'but {len(audio_input_lengths)} audios provided')
start = 0
for audio_idx in range(audio_num):
end = input_ids.index(audio_token_index, start)
new_input_ids.extend(input_ids[start:end]) # text part
feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
else:
# NOTE: WhisperFeatureExtractor cannot handle an empty list of audios
pass
new_input_ids.extend([audio_token_index] *
audio_output_lengths[audio_idx])
start = end + 1
new_input_ids.extend(input_ids[start:])
return super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)
return token_inputs(
prompt_token_ids=new_input_ids,
prompt=inputs.get("prompt"),
multi_modal_data=multi_modal_data,
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
placeholder = hf_config.audio_token_index
feature_attention_mask = hf_inputs.get("feature_attention_mask")
if feature_attention_mask is None:
audio_output_lengths = []
else:
_, audio_output_lengths = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))
def input_mapper_for_qwen2_audio(
ctx: InputContext,
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
"""Input mapper for Qwen2-Audio."""
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]
def get_replacement_qwen2_audio(item_idx: int):
return [placeholder] * audio_output_lengths[item_idx]
if len(multi_modal_data) == 0:
return MultiModalKwargs()
processor = cached_get_processor(ctx.model_config.model)
audio_feature_extractor = processor.feature_extractor
if audio_feature_extractor is None:
raise RuntimeError(
"No HuggingFace audio_feature_extractor is available "
"to process the audio object")
try:
resampled_audios = [
librosa.resample(
audio,
orig_sr=sampling_rate,
target_sr=processor.feature_extractor.sampling_rate)
for audio, sampling_rate in multi_modal_data
return [
PromptReplacement(
modality="audio",
target=[placeholder],
replacement=get_replacement_qwen2_audio,
)
]
batch_data = audio_feature_extractor(resampled_audios,
sampling_rate=16000,
return_attention_mask=True,
padding="max_length",
return_tensors="pt").data
batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
except Exception:
logger.error("Failed to process audio (%s)", multi_modal_data)
raise
return MultiModalKwargs(batch_data)
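Conceptually, `_get_prompt_replacements` expands each single `<|AUDIO|>` placeholder in the prompt into one token per audio embedding. A small sketch with hypothetical values:

audio_token_index = 151646          # hypothetical id of <|AUDIO|>
audio_output_lengths = [750, 120]   # per-item lengths from the formula above

def get_replacement_sketch(item_idx: int) -> list[int]:
    return [audio_token_index] * audio_output_lengths[item_idx]

# The first audio's single placeholder becomes 750 repeated audio tokens,
# the second's becomes 120.
assert len(get_replacement_sketch(0)) == 750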
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
input_mapper_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_qwen2_audio_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
@ -289,9 +214,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler()
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "

View File

@ -3,7 +3,7 @@
import math
from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import numpy as np
@ -11,7 +11,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature
from transformers import BatchFeature, ProcessorMixin
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
@ -25,11 +25,11 @@ from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@ -61,8 +61,8 @@ def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
return cached_feature_extractor(
ctx.get_hf_config(UltravoxConfig).audio_model_id)
hf_config = ctx.get_hf_config(UltravoxConfig)
return cached_feature_extractor(hf_config.audio_model_id)
def get_ultravox_max_audio_tokens(ctx: InputContext):
@ -73,72 +73,71 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().audio_processor.feature_extractor
def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
if sr != feature_extractor.sampling_rate:
try:
import librosa
except ImportError as exc:
raise ImportError(
"Please install vllm[audio] for audio support.") from exc
audio = librosa.resample(audio,
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data or not mm_data.get("audio", None):
return super()._apply_hf_processor(prompt, mm_data,
mm_processor_kwargs)
audio_data = mm_data["audio"]
if not isinstance(audio_data, list):
audio_data = [audio_data]
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
tokenizer = self._get_tokenizer()
audio_features, audio_token_len = [], []
processed_inputs = {}
for audio, sr in audio_data:
data = self._resample_audio(audio, sr)
processed_inputs = super()._apply_hf_processor(
prompt, data, mm_processor_kwargs)
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
skip_special_tokens=False)
audio_features.append(
processed_inputs.pop("audio_values").squeeze(0))
audio_token_len.append(
processed_inputs.pop("audio_token_len").item())
return dict(
**processed_inputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore
def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Ultravox uses "audio" instead of "audios" as calling keyword
processor_data, passthrough_data = super()._get_processor_data(mm_data)
if "audios" in processor_data:
processor_data["audio"] = processor_data.pop("audios")
return processor_data, passthrough_data
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
return super()._get_processor_data(mm_items)
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processor_data = dict(processor_data)
audios = processor_data.pop("audios", [])
if not audios:
return super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)
feature_extractor = self._get_feature_extractor()
mm_processor_kwargs = dict(
**mm_processor_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
# Already resampled by _get_processor_data
assert is_list_of(audios, np.ndarray)
# The Ultravox processor doesn't support multiple audio inputs at once,
# so we have to process the text with each audio one at a time
audio_features, audio_token_len = [], []
shared_outputs = {}
for audio in audios:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
item_processor_data = dict(**processor_data, audio=audio)
item_outputs = super()._call_hf_processor(
hf_processor,
prompt=prompt,
processor_data=item_processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)
audio_features.append(item_outputs.pop("audio_values")[0])
audio_token_len.append(item_outputs.pop("audio_token_len").item())
shared_outputs = item_outputs
combined_outputs = dict(
**shared_outputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
return BatchFeature(combined_outputs)
def _get_prompt_replacements(
self,
@ -147,7 +146,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
placeholder = hf_processor.audio_token_replacement
placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int):
audio_token_len = hf_inputs["audio_token_len"][item_idx]
@ -171,7 +170,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [(audio, sampling_rate)] * audio_count}
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,

View File

@ -1,3 +1,6 @@
import numpy as np
import numpy.typing as npt
from vllm.inputs.registry import InputContext
from .base import MultiModalPlugin
@ -21,3 +24,18 @@ class AudioPlugin(MultiModalPlugin):
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")
def resample_audio(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
target_sr: float,
) -> npt.NDArray[np.floating]:
try:
import librosa
except ImportError as exc:
msg = "Please install vllm[audio] for audio support."
raise ImportError(msg) from exc
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
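Example usage of the new helper (requires `librosa`, i.e. the `vllm[audio]` extra):

import numpy as np

from vllm.multimodal.audio import resample_audio

audio_44k = np.zeros(44100, dtype=np.float32)  # one second of audio at 44.1 kHz
audio_16k = resample_audio(audio_44k, orig_sr=44100, target_sr=16000)
# audio_16k now holds roughly 16000 samples, i.e. one second at 16 kHz.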

View File

@ -15,31 +15,32 @@ _T = TypeVar("_T")
# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image,
which can be passed to a HuggingFace :code:`ImageProcessor`.
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""
VideoItem: TypeAlias = Union[
List[Image],
list[Image],
np.ndarray,
torch.Tensor,
List[np.ndarray],
List[torch.Tensor],
list[np.ndarray],
list[torch.Tensor],
]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video,
which can be passed to a HuggingFace :code:`VideoProcessor`.
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""
AudioItem: TypeAlias = Union[
np.ndarray,
List[float],
Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead
list[float],
# `(audio, sampling_rate)`: If the audio's sampling rate is different
# from that expected by the model, we need to resample it.
tuple[np.ndarray, float],
]
"""
Represents a single audio that can be inputted to a HuggingFace
:code:`AudioProcessor`.
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
"""
# yapf: enable
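The three accepted audio forms, with illustrative values:

import numpy as np

audio_a: AudioItem = np.zeros(16000)             # already at the model's sampling rate
audio_b: AudioItem = [0.0] * 16000               # plain list of float samples
audio_c: AudioItem = (np.zeros(44100), 44100.0)  # (audio, sampling_rate); resampled later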

View File

@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
@ -30,7 +31,7 @@ _PromptSeq = Union[str, list[int]]
@dataclass
class PromptReplacement:
modality: str
"""The modality for which the replacement is made"""
"""The modality for which the replacement is made."""
target: _PromptSeq
"""The text or token sequence to find and replace."""
@ -211,20 +212,48 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
corresponds to a list.
"""
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items():
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if is_list_of(v, (list, torch.Tensor)) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (list, torch.Tensor)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data
# NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
# `self.images` doesn't update this dictionary, which may be confusing.
# We annotate the getter methods as `Sequence` to prevent others from
# trying to update the list in this way.
@property
def image(self) -> list[ImageItem]:
return self["image"]
def images(self) -> Sequence[ImageItem]:
return self.get("image", [])
@property
def video(self) -> list[VideoItem]:
return self["video"]
def videos(self) -> Sequence[VideoItem]:
return self.get("video", [])
@property
def audio(self) -> list[AudioItem]:
return self["audio"]
def audios(self) -> Sequence[AudioItem]:
return self.get("audio", [])
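A usage sketch of the normalization and the new `Sequence`-typed getters (hypothetical data; `MultiModalDataItems` lives in `vllm.multimodal.processing`):

import numpy as np

items = MultiModalDataItems.from_dict({"audio": np.zeros(16000)})
assert len(items.audios) == 1    # a bare array is wrapped into a one-item list
assert list(items.videos) == []  # missing modalities read as empty sequences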
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.image[item_idx]
image = self.images[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
@ -234,25 +263,41 @@ class MultiModalDataItems(UserDict[str, list[Any]]):
assert_never(image)
def get_audio_with_sr(
self,
item_idx: int,
*,
default_sr: float,
) -> tuple[np.ndarray, float]:
audio = self.audios[item_idx]
def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), default_sr
if isinstance(audio, np.ndarray):
return audio, default_sr
for k, v in data.items():
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index]
elif k in ("image", "audio"):
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
assert_never(audio)
return multi_data
def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
"""
If :code:`drop_sr=True`, the audio items in this dictionary are updated
to be NumPy arrays which implicitly means that their sampling rate is
the same as the model's expected sampling rate; otherwise, they remain
as :code:`(audio, new_sr)` tuples.
"""
if not self.audios:
return
new_audios = []
for item_idx in range(len(self.audios)):
audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
new_audios.append(audio if drop_sr else (audio, new_sr))
self["audio"] = new_audios
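A sketch of the resampling flow, again assuming `librosa` is installed via `vllm[audio]`:

import numpy as np

items = MultiModalDataItems.from_dict({"audio": [(np.zeros(44100), 44100)]})
items.resample_audios(16000)
# With drop_sr=True (the default), the item is now a bare ndarray at 16 kHz.
assert isinstance(items.audios[0], np.ndarray)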
class _TokenMatch(NamedTuple):
@ -596,18 +641,20 @@ class BaseMultiModalProcessor(ABC):
def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> BatchFeature:
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():
for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif is_list_of(v, torch.Tensor) and v[0].ndim == 2:
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
@ -615,40 +662,41 @@ class BaseMultiModalProcessor(ABC):
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data
def _call_hf_processor(
self,
hf_processor: ProcessorMixin,
prompt: str,
processor_data: Mapping[str, object],
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
return self.ctx.call_hf_processor(
hf_processor,
prompt,
processor_data,
mm_processor_kwargs,
)
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data, passthrough_data = self._get_processor_data(mm_data)
processor_data, passthrough_data = self._get_processor_data(mm_items)
assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
hf_inputs = self._call_hf_processor(
hf_processor,
mm_processor_kwargs,
prompt=prompt,
processor_data=processor_data,
mm_processor_kwargs=mm_processor_kwargs,
)
try:
hf_inputs = hf_processor(
text=prompt, # type: ignore
**processor_data,
**mm_processor_kwargs,
return_tensors="pt",
)
except Exception as exc:
data = dict(text=prompt, **processor_data)
raise RuntimeError(
f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={mm_processor_kwargs}") from exc
hf_inputs.update(passthrough_data)
return hf_inputs
@ -730,14 +778,13 @@ class BaseMultiModalProcessor(ABC):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
tokenizer = self._get_tokenizer()
mm_items = MultiModalDataItems.from_dict(mm_data)
hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
hf_inputs = self._apply_hf_processor(prompt_text, mm_items,
mm_processor_kwargs)
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)
mm_items = to_multi_format(mm_data)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
@ -749,6 +796,7 @@ class BaseMultiModalProcessor(ABC):
prompt_ids, mm_item_counts)
if all_placeholders:
tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids)
else:
(

View File

@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers
def is_list_of(
value: object,
typ: Type[T],
typ: Union[type[T], tuple[type[T], ...]],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
@ -1282,6 +1282,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def supports_kw(
callable: Callable[..., object],
kw_name: str,
*,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:
@ -1326,6 +1327,8 @@ def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
@ -1344,11 +1347,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs)
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
@ -1359,6 +1368,8 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
@ -1390,16 +1401,21 @@ def get_allowed_kwarg_only_overrides(
for kwarg_name, val in overrides.items()
if supports_kw(callable,
kwarg_name,
requires_kw_only=True,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs)
}
# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and will be dropped: %s", dropped_keys)
return filtered_overrides
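To illustrate the new `requires_kw_only` switch with a toy callable (a sketch; the exact filtering is defined by `supports_kw` above):

from vllm.utils import get_allowed_kwarg_only_overrides

def toy_processor(text, audios=None, sampling_rate=16000, *, padding="longest"):
    ...

overrides = {"sampling_rate": 22050, "padding": "max_length", "bogus": 1}

# Default (requires_kw_only=True): only the keyword-only `padding` survives.
strict = get_allowed_kwarg_only_overrides(toy_processor, overrides)

# Relaxed mode, as now used when calling HF processors: positional-or-keyword
# parameters such as `sampling_rate` are accepted too; `bogus` is still dropped
# because `toy_processor` neither declares it nor accepts **kwargs.
relaxed = get_allowed_kwarg_only_overrides(
    toy_processor, overrides, requires_kw_only=False, allow_var_kwargs=True)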