[Model] Refactor Ultravox to use merged input processor (#11198)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
bddbbcb132
commit
d927dbcd88
@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{
|
||||
'role':
|
||||
'user',
|
||||
'content':
|
||||
"<|reserved_special_token_0|>\n" * audio_count + question
|
||||
'role': 'user',
|
||||
'content': "<|audio|>\n" * audio_count + question
|
||||
}]
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
|
||||
llm = LLM(model=model_name,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
|
||||
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
|
||||
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
|
||||
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
|
||||
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
# TODO: Implement PP
|
||||
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
|
||||
|
@ -25,6 +25,7 @@ def server():
|
||||
"--max-num-seqs",
|
||||
"5",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
|
@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
|
||||
|
||||
AudioTuple = Tuple[np.ndarray, int]
|
||||
|
||||
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
|
||||
VLLM_PLACEHOLDER = "<|audio|>"
|
||||
HF_PLACEHOLDER = "<|audio|>"
|
||||
|
||||
CHUNKED_PREFILL_KWARGS = {
|
||||
@ -46,7 +46,8 @@ def audio(request):
|
||||
def server(request, audio_assets):
|
||||
args = [
|
||||
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
|
||||
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
|
||||
f"--limit-mm-per-prompt=audio={len(audio_assets)}",
|
||||
"--trust-remote-code"
|
||||
] + [
|
||||
f"--{key.replace('_','-')}={value}"
|
||||
for key, value in request.param.items()
|
||||
|
@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
raise TypeError(f"Unknown {modality} model type: {model_type}")
|
||||
elif modality == "audio":
|
||||
if model_type == "ultravox":
|
||||
return "<|reserved_special_token_0|>"
|
||||
return "<|audio|>"
|
||||
if model_type == "qwen2_audio":
|
||||
return (f"Audio {current_count}: "
|
||||
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
|
||||
|
@ -3,41 +3,39 @@
|
||||
|
||||
import math
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union, cast)
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers import BatchFeature
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperEncoder
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import IntermediateTensors, SequenceData
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataDict,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings_from_map)
|
||||
|
||||
_AUDIO_PLACEHOLDER_TOKEN = 128002
|
||||
_AUDIO_TOKENS_PER_SECOND = 6.25
|
||||
|
||||
|
||||
@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
|
||||
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
|
||||
|
||||
|
||||
def dummy_seq_data_for_ultravox(
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
audio_count: int,
|
||||
):
|
||||
audio_length = min(get_ultravox_max_audio_tokens(ctx),
|
||||
seq_len // audio_count)
|
||||
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
|
||||
(0, seq_len - audio_length * audio_count)), {
|
||||
"audio":
|
||||
consecutive_placeholder_ranges(num_items=audio_count,
|
||||
item_size=audio_length)
|
||||
}
|
||||
|
||||
|
||||
def dummy_audio_for_ultravox(
|
||||
ctx: InputContext,
|
||||
audio_count: int,
|
||||
):
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
|
||||
return {"audio": [audio_and_sr] * audio_count}
|
||||
|
||||
|
||||
def dummy_data_for_ultravox(
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
):
|
||||
audio_count = mm_counts["audio"]
|
||||
seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
|
||||
mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
|
||||
|
||||
return DummyData(seq_data, mm_dict, ranges)
|
||||
|
||||
|
||||
def input_mapper_for_ultravox(ctx: InputContext, data: object):
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
|
||||
if len(data) == 0:
|
||||
return MultiModalKwargs()
|
||||
|
||||
# If the audio inputs are embeddings, no need for preprocessing
|
||||
if is_list_of(data, torch.Tensor, check="all"):
|
||||
return MultiModalKwargs({"audio_embeds": data})
|
||||
|
||||
audio_features = []
|
||||
for audio_input in data:
|
||||
if not isinstance(audio_input, tuple):
|
||||
raise NotImplementedError(
|
||||
f"Unsupported data type: {type(audio_input)}")
|
||||
|
||||
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||
return self._get_hf_processor().audio_processor.feature_extractor
|
||||
|
||||
def _resample_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
sr: int,
|
||||
) -> Dict[str, Union[np.ndarray, int]]:
|
||||
# resample audio to the model's sampling rate
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
if sr != feature_extractor.sampling_rate:
|
||||
try:
|
||||
import librosa
|
||||
@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
|
||||
orig_sr=sr,
|
||||
target_sr=feature_extractor.sampling_rate)
|
||||
sr = feature_extractor.sampling_rate
|
||||
return {"audio": audio, "sampling_rate": sr}
|
||||
|
||||
minimum_audio_length = feature_extractor.n_fft // 2 + 1
|
||||
if len(audio) < minimum_audio_length:
|
||||
# Not enough audio; pad it.
|
||||
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
|
||||
def _apply_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if not mm_data or not mm_data.get("audio", None):
|
||||
return super()._apply_hf_processor(prompt, mm_data,
|
||||
mm_processor_kwargs)
|
||||
|
||||
single_audio_features = feature_extractor(
|
||||
audio, sampling_rate=sr, padding="longest",
|
||||
return_tensors="pt")["input_features"]
|
||||
audio_data = mm_data["audio"]
|
||||
if not isinstance(audio_data, list):
|
||||
audio_data = [audio_data]
|
||||
|
||||
# Remove the batch dimension because we're wrapping it in a list.
|
||||
audio_features.append(single_audio_features.squeeze(0))
|
||||
# Ultravox processor doesn't support multiple inputs,
|
||||
# therefore we need to input text and audio one by one
|
||||
tokenizer = self._get_tokenizer()
|
||||
audio_features, audio_token_len = [], []
|
||||
processed_inputs = {}
|
||||
for audio, sr in audio_data:
|
||||
data = self._resample_audio(audio, sr)
|
||||
processed_inputs = super()._apply_hf_processor(
|
||||
prompt, data, mm_processor_kwargs)
|
||||
prompt = tokenizer.decode(processed_inputs["input_ids"][0],
|
||||
skip_special_tokens=False)
|
||||
audio_features.append(
|
||||
processed_inputs.pop("audio_values").squeeze(0))
|
||||
audio_token_len.append(
|
||||
processed_inputs.pop("audio_token_len").item())
|
||||
|
||||
return MultiModalKwargs({"audio_features": audio_features})
|
||||
return dict(
|
||||
**processed_inputs,
|
||||
audio_features=audio_features,
|
||||
audio_token_len=audio_token_len,
|
||||
)
|
||||
|
||||
def _get_processor_data(
|
||||
self,
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
# Ultravox uses "audio" instead of "audios" as calling keyword
|
||||
processor_data, passthrough_data = super()._get_processor_data(mm_data)
|
||||
if "audios" in processor_data:
|
||||
processor_data["audio"] = processor_data.pop("audios")
|
||||
return processor_data, passthrough_data
|
||||
|
||||
def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "audio" not in multi_modal_data:
|
||||
return inputs
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_inputs: BatchFeature,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self._get_hf_processor()
|
||||
placeholder = hf_processor.audio_token_replacement
|
||||
|
||||
if "multi_modal_placeholders" in inputs and "audio" in inputs[
|
||||
"multi_modal_placeholders"]:
|
||||
# The inputs already have placeholders.
|
||||
return inputs
|
||||
def get_replacement_ultravox(item_idx: int):
|
||||
audio_token_len = hf_inputs["audio_token_len"][item_idx]
|
||||
return placeholder * audio_token_len
|
||||
|
||||
feature_extractor = whisper_feature_extractor(ctx)
|
||||
audios = multi_modal_data["audio"]
|
||||
if not isinstance(audios, list):
|
||||
audios = [audios]
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target="<|audio|>",
|
||||
replacement=get_replacement_ultravox,
|
||||
)
|
||||
]
|
||||
|
||||
audio_token_counts = []
|
||||
for audio in audios:
|
||||
if isinstance(audio, torch.Tensor):
|
||||
audio_num_tokens = audio.shape[1]
|
||||
audio_token_counts.append(audio_num_tokens)
|
||||
else:
|
||||
audio_data, sample_rate = audio
|
||||
audio_length = audio_data.shape[0]
|
||||
if sample_rate != feature_extractor.sampling_rate:
|
||||
# Account for resampling.
|
||||
adjustment = feature_extractor.sampling_rate / sample_rate
|
||||
audio_length = math.ceil(adjustment * audio_length)
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
feature_extractor = self._get_feature_extractor()
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
|
||||
feature_extractor_output_length = math.ceil(
|
||||
(audio_length - (feature_extractor.hop_length - 1)) /
|
||||
feature_extractor.hop_length)
|
||||
audio_count = mm_counts["audio"]
|
||||
audio = np.zeros(audio_len)
|
||||
data = {"audio": [(audio, sampling_rate)] * audio_count}
|
||||
|
||||
uv_config = ctx.get_hf_config(UltravoxConfig)
|
||||
audio_num_tokens = min(
|
||||
max(
|
||||
1,
|
||||
math.ceil(feature_extractor_output_length /
|
||||
(uv_config.stack_factor * 2))),
|
||||
get_ultravox_max_audio_tokens(ctx))
|
||||
audio_token_counts.append(audio_num_tokens)
|
||||
|
||||
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
|
||||
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
|
||||
repeat_count=audio_token_counts,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"audio": ranges})
|
||||
return ProcessorInputs(
|
||||
prompt_text="<|audio|>" * audio_count,
|
||||
mm_data=data,
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
|
||||
class StackAudioFrames(nn.Module):
|
||||
@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", get_ultravox_max_audio_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
|
||||
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
|
||||
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC):
|
||||
return list(
|
||||
iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
|
||||
|
||||
def _apply_hf_processor(
|
||||
def _get_processor_data(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
|
||||
|
||||
processor_data = dict[str, Any]()
|
||||
passthrough_data = dict[str, Any]()
|
||||
for k, v in mm_data.items():
|
||||
@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC):
|
||||
processor_data[f"{k}s"] = v
|
||||
else:
|
||||
processor_data[k] = v
|
||||
return processor_data, passthrough_data
|
||||
|
||||
def _apply_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# some mm_processor_kwargs may be used in processor initialization
|
||||
# instead of processor call
|
||||
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
|
||||
|
||||
processor_data, passthrough_data = self._get_processor_data(mm_data)
|
||||
|
||||
assert callable(hf_processor)
|
||||
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
|
||||
|
Loading…
x
Reference in New Issue
Block a user