[Model] Refactor Ultravox to use merged input processor (#11198)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Authored by Isotr0py on 2024-12-16 18:09:53 +08:00; committed by GitHub
parent bddbbcb132
commit d927dbcd88
7 changed files with 121 additions and 146 deletions

View File

@@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     messages = [{
-        'role':
-        'user',
-        'content':
-        "<|reserved_special_token_0|>\n" * audio_count + question
+        'role': 'user',
+        'content': "<|audio|>\n" * audio_count + question
     }]
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
-    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name,
+              trust_remote_code=True,
+              limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
     return llm, prompt, stop_token_ids

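A minimal sketch of how the refactored example could be driven end to end, assuming run_ultravox() as defined above; the silent stand-in clip, sampling parameters, and single-prompt generate call are illustrative, not part of this commit:

import numpy as np
from vllm import SamplingParams

# Hypothetical driver for run_ultravox(); a 1-second silent clip stands in
# for real audio loaded as an (np.ndarray, sample_rate) tuple.
llm, prompt, stop_token_ids = run_ultravox("What can you hear?", audio_count=1)
audio = (np.zeros(16000, dtype=np.float32), 16000)

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"audio": [audio]},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)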
View File

@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
     "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
-    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
+    "fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-decoder]
     # TODO: Implement PP
     # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),

View File

@@ -25,6 +25,7 @@ def server():
         "--max-num-seqs",
         "5",
         "--enforce-eager",
+        "--trust-remote-code",
     ]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -16,7 +16,7 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
 AudioTuple = Tuple[np.ndarray, int]
-VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+VLLM_PLACEHOLDER = "<|audio|>"
 HF_PLACEHOLDER = "<|audio|>"
 CHUNKED_PREFILL_KWARGS = {

@@ -46,7 +46,8 @@ def audio(request):
 def server(request, audio_assets):
     args = [
         "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
+        "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()

View File

@@ -418,7 +418,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
-                return "<|reserved_special_token_0|>"
+                return "<|audio|>"
             if model_type == "qwen2_audio":
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")

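For the OpenAI-compatible server, the placeholder returned here is what the chat template machinery inserts for each audio content part. A hedged sketch of a request that would exercise this path; the server address, model name, and audio URL are placeholders, and the audio_url content type is assumed to be available:

from openai import OpenAI

# Assumes a vLLM server already running for fixie-ai/ultravox-v0_3
# (started with --trust-remote-code, as the tests in this commit do).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_3",
    messages=[{
        "role": "user",
        "content": [
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/clip.wav"}},
            {"type": "text", "text": "Transcribe the audio."},
        ],
    }],
    max_tokens=64,
)
print(chat.choices[0].message.content)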
View File

@@ -3,41 +3,39 @@
 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union, cast)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
+                    Tuple, TypedDict, Union)
 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
+from transformers import BatchFeature
 from transformers.models.whisper import WhisperFeatureExtractor
 from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataDict,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
-from vllm.utils import is_list_of
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings_from_map)
-_AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25

@@ -72,64 +70,18 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
     return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
-def dummy_seq_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    audio_count: int,
-):
-    audio_length = min(get_ultravox_max_audio_tokens(ctx),
-                       seq_len // audio_count)
-    return SequenceData.from_prompt_token_counts(
-        (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count),
-        (0, seq_len - audio_length * audio_count)), {
-            "audio":
-            consecutive_placeholder_ranges(num_items=audio_count,
-                                           item_size=audio_length)
-        }
-
-def dummy_audio_for_ultravox(
-    ctx: InputContext,
-    audio_count: int,
-):
-    feature_extractor = whisper_feature_extractor(ctx)
-    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
-    return {"audio": [audio_and_sr] * audio_count}
-
-def dummy_data_for_ultravox(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-):
-    audio_count = mm_counts["audio"]
-    seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count)
-    mm_dict = dummy_audio_for_ultravox(ctx, audio_count)
-    return DummyData(seq_data, mm_dict, ranges)
-
-def input_mapper_for_ultravox(ctx: InputContext, data: object):
-    if not isinstance(data, list):
-        data = [data]
-    if len(data) == 0:
-        return MultiModalKwargs()
-    # If the audio inputs are embeddings, no need for preprocessing
-    if is_list_of(data, torch.Tensor, check="all"):
-        return MultiModalKwargs({"audio_embeds": data})
-    audio_features = []
-    for audio_input in data:
-        if not isinstance(audio_input, tuple):
-            raise NotImplementedError(
-                f"Unsupported data type: {type(audio_input)}")
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
-        feature_extractor = whisper_feature_extractor(ctx)
+class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
+
+    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
+        return self._get_hf_processor().audio_processor.feature_extractor
+
+    def _resample_audio(
+        self,
+        audio: np.ndarray,
+        sr: int,
+    ) -> Dict[str, Union[np.ndarray, int]]:
+        # resample audio to the model's sampling rate
+        feature_extractor = self._get_feature_extractor()
         if sr != feature_extractor.sampling_rate:
             try:
                 import librosa

@@ -140,78 +92,92 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
                                      orig_sr=sr,
                                      target_sr=feature_extractor.sampling_rate)
             sr = feature_extractor.sampling_rate
-        minimum_audio_length = feature_extractor.n_fft // 2 + 1
-        if len(audio) < minimum_audio_length:
-            # Not enough audio; pad it.
-            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
-        single_audio_features = feature_extractor(
-            audio, sampling_rate=sr, padding="longest",
-            return_tensors="pt")["input_features"]
-        # Remove the batch dimension because we're wrapping it in a list.
-        audio_features.append(single_audio_features.squeeze(0))
-    return MultiModalKwargs({"audio_features": audio_features})
-
-def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "audio" not in multi_modal_data:
-        return inputs
-    if "multi_modal_placeholders" in inputs and "audio" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-    feature_extractor = whisper_feature_extractor(ctx)
-    audios = multi_modal_data["audio"]
-    if not isinstance(audios, list):
-        audios = [audios]
-    audio_token_counts = []
-    for audio in audios:
-        if isinstance(audio, torch.Tensor):
-            audio_num_tokens = audio.shape[1]
-            audio_token_counts.append(audio_num_tokens)
-        else:
-            audio_data, sample_rate = audio
-            audio_length = audio_data.shape[0]
-            if sample_rate != feature_extractor.sampling_rate:
-                # Account for resampling.
-                adjustment = feature_extractor.sampling_rate / sample_rate
-                audio_length = math.ceil(adjustment * audio_length)
-            feature_extractor_output_length = math.ceil(
-                (audio_length - (feature_extractor.hop_length - 1)) /
-                feature_extractor.hop_length)
-            uv_config = ctx.get_hf_config(UltravoxConfig)
-            audio_num_tokens = min(
-                max(
-                    1,
-                    math.ceil(feature_extractor_output_length /
-                              (uv_config.stack_factor * 2))),
-                get_ultravox_max_audio_tokens(ctx))
-            audio_token_counts.append(audio_num_tokens)
-    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
-    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-        tokenizer,
-        inputs.get("prompt"),
-        inputs["prompt_token_ids"],
-        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
-        repeat_count=audio_token_counts,
-    )
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"audio": ranges})
+        return {"audio": audio, "sampling_rate": sr}
+
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data or not mm_data.get("audio", None):
+            return super()._apply_hf_processor(prompt, mm_data,
+                                               mm_processor_kwargs)
+        audio_data = mm_data["audio"]
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+        # Ultravox processor doesn't support multiple inputs,
+        # therefore we need to input text and audio one by one
+        tokenizer = self._get_tokenizer()
+        audio_features, audio_token_len = [], []
+        processed_inputs = {}
+        for audio, sr in audio_data:
+            data = self._resample_audio(audio, sr)
+            processed_inputs = super()._apply_hf_processor(
+                prompt, data, mm_processor_kwargs)
+            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
                                      skip_special_tokens=False)
+            audio_features.append(
+                processed_inputs.pop("audio_values").squeeze(0))
+            audio_token_len.append(
+                processed_inputs.pop("audio_token_len").item())
+        return dict(
+            **processed_inputs,
+            audio_features=audio_features,
+            audio_token_len=audio_token_len,
+        )
+
+    def _get_processor_data(
+        self,
+        mm_data: MultiModalDataDict,
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        # Ultravox uses "audio" instead of "audios" as calling keyword
+        processor_data, passthrough_data = super()._get_processor_data(mm_data)
+        if "audios" in processor_data:
+            processor_data["audio"] = processor_data.pop("audios")
+        return processor_data, passthrough_data
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_inputs: BatchFeature,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> list[PromptReplacement]:
+        hf_processor = self._get_hf_processor()
+        placeholder = hf_processor.audio_token_replacement
+
+        def get_replacement_ultravox(item_idx: int):
+            audio_token_len = hf_inputs["audio_token_len"][item_idx]
+            return placeholder * audio_token_len
+
+        return [
+            PromptReplacement(
+                modality="audio",
+                target="<|audio|>",
+                replacement=get_replacement_ultravox,
+            )
+        ]
+
+    def _get_dummy_mm_inputs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+        audio_count = mm_counts["audio"]
+        audio = np.zeros(audio_len)
+        data = {"audio": [(audio, sampling_rate)] * audio_count}
+        return ProcessorInputs(
+            prompt_text="<|audio|>" * audio_count,
+            mm_data=data,
+            mm_processor_kwargs={},
+        )
 class StackAudioFrames(nn.Module):

@@ -332,11 +298,9 @@ class ModifiedWhisperEncoder(WhisperEncoder):
         return hidden_states
-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
     "audio", get_ultravox_max_audio_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

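Conceptually, the PromptReplacement registered by the new processor expands each "<|audio|>" in the prompt into as many audio placeholder tokens as that clip's features occupy. A standalone toy illustration of that expansion; the helper name and token counts below are made up for the example:

def expand_audio_placeholders(prompt: str, audio_token_lens: list[int],
                              placeholder: str = "<|audio|>") -> str:
    # Replace the i-th occurrence of the placeholder with
    # audio_token_lens[i] repeated copies of it, mirroring what
    # get_replacement_ultravox() does per item above.
    parts = prompt.split(placeholder)
    assert len(parts) == len(audio_token_lens) + 1, "one length per placeholder"
    out = parts[0]
    for length, tail in zip(audio_token_lens, parts[1:]):
        out += placeholder * length + tail
    return out

# Two clips whose encoded features occupy 8 and 5 tokens respectively.
print(expand_audio_placeholders("<|audio|>\n<|audio|>\nDescribe both clips.", [8, 5]))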
View File

@@ -594,14 +594,10 @@ class BaseMultiModalProcessor(ABC):
         return list(
             iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
-    def _apply_hf_processor(
+    def _get_processor_data(
         self,
-        prompt: str,
         mm_data: MultiModalDataDict,
-        mm_processor_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        hf_processor = self._get_hf_processor(**mm_processor_kwargs)
         processor_data = dict[str, Any]()
         passthrough_data = dict[str, Any]()
         for k, v in mm_data.items():

@@ -619,6 +615,19 @@ class BaseMultiModalProcessor(ABC):
                     processor_data[f"{k}s"] = v
             else:
                 processor_data[k] = v
+        return processor_data, passthrough_data
+
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        # some mm_processor_kwargs may be used in processor initialization
+        # instead of processor call
+        hf_processor = self._get_hf_processor(**mm_processor_kwargs)
+        processor_data, passthrough_data = self._get_processor_data(mm_data)
+
         assert callable(hf_processor)
         mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(