[Model] Refactor Ultravox to use merged input processor (#11198)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Isotr0py 2024-12-16 18:09:53 +08:00 committed by GitHub
parent bddbbcb132
commit d927dbcd88
7 changed files with 121 additions and 146 deletions

View File

@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role':
'user',
'content':
"<|reserved_special_token_0|>\n" * audio_count + question
'role': 'user',
'content': "<|audio|>\n" * audio_count + question
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name,
trust_remote_code=True,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids
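
Taken together, the example now uses the literal "<|audio|>" marker (one per clip) and passes trust_remote_code=True when building the LLM. A minimal sketch of the updated helper, assuming the fixie-ai/ultravox-v0_3 checkpoint referenced elsewhere in this commit and the imports from the surrounding example script (neither is shown in this hunk):

from transformers import AutoTokenizer

from vllm import LLM

# Assumed model name; the example script's actual value is not part of this hunk.
model_name = "fixie-ai/ultravox-v0_3"


def run_ultravox(question: str, audio_count: int):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # One "<|audio|>" marker per audio clip, followed by the text question.
    messages = [{
        'role': 'user',
        'content': "<|audio|>\n" * audio_count + question
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # trust_remote_code is passed so the model's Hugging Face processor can be
    # loaded, matching the flag added to the tests in this commit.
    llm = LLM(model=model_name,
              trust_remote_code=True,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids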

View File

@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
# [Encoder-decoder]
# TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),

View File

@@ -25,6 +25,7 @@ def server():
"--max-num-seqs",
"5",
"--enforce-eager",
"--trust-remote-code",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
VLLM_PLACEHOLDER = "<|audio|>"
HF_PLACEHOLDER = "<|audio|>"
CHUNKED_PREFILL_KWARGS = {
@@ -46,7 +46,8 @@ def audio(request):
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
f"--limit-mm-per-prompt=audio={len(audio_assets)}",
"--trust-remote-code"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()

View File

@@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
return "<|audio|>"
if model_type == "qwen2_audio":
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")

View File

@@ -3,41 +3,39 @@
import math
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union, cast)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataDict,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
def dummy_seq_data_for_ultravox(
ctx: InputContext,
seq_len: int,
audio_count: int,
):
audio_length = min(get_ultravox_max_audio_tokens(ctx),
seq_len // audio_count)
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
return SequenceData.from_prompt_token_counts(
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
(0, seq_len - audio_length * audio_count)), {
"audio":
consecutive_placeholder_ranges(num_items=audio_count,
item_size=audio_length)
}
def dummy_audio_for_ultravox(
ctx: InputContext,
audio_count: int,
):
feature_extractor = whisper_feature_extractor(ctx)
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
return {"audio": [audio_and_sr] * audio_count}
def dummy_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
):
audio_count = mm_counts["audio"]
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
return DummyData(seq_data, mm_dict, ranges)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if not isinstance(data, list):
data = [data]
if len(data) == 0:
return MultiModalKwargs()
# If the audio inputs are embeddings, no need for preprocessing
if is_list_of(data, torch.Tensor, check="all"):
return MultiModalKwargs({"audio_embeds": data})
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().audio_processor.feature_extractor
def _resample_audio(
self,
audio: np.ndarray,
sr: int,
) -> Dict[str, Union[np.ndarray, int]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor()
if sr != feature_extractor.sampling_rate:
try:
import librosa
@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
return {"audio": audio, "sampling_rate": sr}
minimum_audio_length = feature_extractor.n_fft // 2 + 1
if len(audio) < minimum_audio_length:
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data or not mm_data.get("audio", None):
return super()._apply_hf_processor(prompt, mm_data,
mm_processor_kwargs)
single_audio_features = feature_extractor(
audio, sampling_rate=sr, padding="longest",
return_tensors="pt")["input_features"]
audio_data = mm_data["audio"]
if not isinstance(audio_data, list):
audio_data = [audio_data]
# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))
# Ultravox processor doesn't support multiple inputs,
# therefore we need to input text and audio one by one
tokenizer = self._get_tokenizer()
audio_features, audio_token_len = [], []
processed_inputs = {}
for audio, sr in audio_data:
data = self._resample_audio(audio, sr)
processed_inputs = super()._apply_hf_processor(
prompt, data, mm_processor_kwargs)
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
skip_special_tokens=False)
audio_features.append(
processed_inputs.pop("audio_values").squeeze(0))
audio_token_len.append(
processed_inputs.pop("audio_token_len").item())
return MultiModalKwargs({"audio_features": audio_features})
return dict(
**processed_inputs,
audio_features=audio_features,
audio_token_len=audio_token_len,
)
def _get_processor_data(
self,
mm_data: MultiModalDataDict,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# Ultravox uses "audio" instead of "audios" as calling keyword
processor_data, passthrough_data = super()._get_processor_data(mm_data)
if "audios" in processor_data:
processor_data["audio"] = processor_data.pop("audios")
return processor_data, passthrough_data
def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return inputs
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
placeholder = hf_processor.audio_token_replacement
if "multi_modal_placeholders" in inputs and "audio" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
def get_replacement_ultravox(item_idx: int):
audio_token_len = hf_inputs["audio_token_len"][item_idx]
return placeholder * audio_token_len
feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
return [
PromptReplacement(
modality="audio",
target="<|audio|>",
replacement=get_replacement_ultravox,
)
]
audio_token_counts = []
for audio in audios:
if isinstance(audio, torch.Tensor):
audio_num_tokens = audio.shape[1]
audio_token_counts.append(audio_num_tokens)
else:
audio_data, sample_rate = audio
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)
audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
data = {"audio": [(audio, sampling_rate)] * audio_count}
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"audio": ranges})
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
mm_processor_kwargs={},
)
class StackAudioFrames(nn.Module):
@@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
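
Net effect in this file: the four legacy registrations (input mapper, max multimodal tokens, dummy data, input processor) collapse into a single register_processor(UltravoxMultiModalProcessor) call, and placeholder handling becomes a declarative PromptReplacement that repeats the HF processor's audio placeholder audio_token_len times per clip. A plain-Python illustration of that expansion (the helper and names are hypothetical, not vLLM internals):

def expand_audio_markers(prompt: str, audio_token_lens: list,
                         replacement_token: str) -> str:
    # Replace each "<|audio|>" marker with the number of placeholder tokens
    # the HF processor reported for that clip (audio_token_len).
    parts = prompt.split("<|audio|>")
    assert len(parts) == len(audio_token_lens) + 1, "one marker per audio clip"
    out = parts[0]
    for token_len, rest in zip(audio_token_lens, parts[1:]):
        out += replacement_token * token_len + rest
    return out

# Two clips whose features span 3 and 5 audio tokens:
# expand_audio_markers("<|audio|>\n<|audio|>\nWhat did you hear?", [3, 5], "<a>")
# -> "<a><a><a>\n<a><a><a><a><a>\nWhat did you hear?"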

View File

@@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC):
return list(
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
def _apply_hf_processor(
def _get_processor_data(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_data.items():
@@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC):
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data, passthrough_data = self._get_processor_data(mm_data)
assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
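
The processing.py side of the refactor factors the mm_data-to-keyword mapping out of _apply_hf_processor into a _get_processor_data hook, so a model whose HF processor expects a different keyword (Ultravox wants "audio" rather than the pluralized "audios") only overrides that hook instead of reimplementing the whole method. A hypothetical subclass sketch of the extension point, mirroring the Ultravox override above:

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyAudioMultiModalProcessor(BaseMultiModalProcessor):

    def _get_processor_data(self, mm_data):
        # Start from the base mapping, then rename the key to whatever this
        # model's HF processor expects.
        processor_data, passthrough_data = super()._get_processor_data(mm_data)
        if "audios" in processor_data:
            processor_data["audio"] = processor_data.pop("audios")
        return processor_data, passthrough_data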