[Model] Add UltravoxModel and UltravoxConfig (#7615)

This commit is contained in:
Peter Salas 2024-08-21 15:49:39 -07:00 committed by GitHub
parent dd53c4b023
commit 1ca0d4f86b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 1087 additions and 261 deletions

View File

@ -186,7 +186,7 @@ Multimodal Language Models
* - Architecture
- Models
- Supported Modality(ies)
- Supported Modalities
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
@ -234,6 +234,11 @@ Multimodal Language Models
- Image
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code: `UltravoxModel`
- Ultravox
- Audio
- :code: `fixie-ai/ultravox-v0_3`
-
.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.

View File

@ -0,0 +1,97 @@
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser
# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
question = "What is recited in the audio?"
# Ultravox 0.3
def run_ultravox(question):
model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': f"<|reserved_special_token_0|>\n{question}"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name)
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
"ultravox": run_ultravox,
}
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
llm, prompt, stop_token_ids = model_example_map[model](question)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
}
else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
} for _ in range(args.num_prompts)]
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--model-type',
'-m',
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,90 @@
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
vllm serve fixie-ai/ultravox-v0_3
"""
import base64
import requests
from openai import OpenAI
from vllm.assets.audio import AudioAsset
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Any format supported by librosa is supported
audio_url = AudioAsset("winning_call").url
# Use audio url in the payload
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
},
},
],
}],
model=model,
max_tokens=64,
)
result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")
# Use base64 encoded audio in the payload
def encode_audio_base64_from_url(audio_url: str) -> str:
"""Encode an audio retrieved from a remote url to base64 format."""
with requests.get(audio_url) as response:
response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8')
return result
audio_base64 = encode_audio_base64_from_url(audio_url=audio_url)
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
},
},
],
}],
model=model,
max_tokens=64,
)
result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")

View File

@ -9,14 +9,14 @@ from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature)
from vllm import LLM, SamplingParams
@ -216,8 +216,7 @@ class HfRunner:
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
auto_cls=AutoModelForCausalLM,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
@ -234,13 +233,6 @@ class HfRunner:
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
elif is_encoder_decoder_model:
auto_cls = AutoModelForSeq2SeqLM
else:
auto_cls = AutoModelForCausalLM
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
@ -432,6 +424,7 @@ class HfRunner:
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
@ -446,6 +439,11 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if audios is not None:
audio, sr = audios[i]
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
@ -627,6 +625,8 @@ class VllmRunner:
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
@ -638,6 +638,10 @@ class VllmRunner:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
@ -674,6 +678,8 @@ class VllmRunner:
num_logprobs: int,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
@ -682,7 +688,8 @@ class VllmRunner:
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)
images=images,
audios=audios)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]

View File

@ -10,6 +10,7 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
"""
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.utils import cuda_device_count_stateless
@ -85,7 +86,7 @@ def test_models(
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_prompts,
max_tokens,

View File

@ -1,138 +1,36 @@
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch
from typing import Dict, List
import librosa
import numpy as np
import openai
import pytest
import requests
import torch
from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.assets.audio import AudioAsset
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port
from ...utils import VLLM_PATH
from ...utils import RemoteOpenAIServer
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
MODEL_NAME = "facebook/opt-125m"
MODEL_NAME = "fixie-ai/ultravox-v0_3"
TEST_AUDIO_URLS = [
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
AudioAsset("winning_call").url,
]
def server_function(port):
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
]
def fake_input_mapper(ctx: InputContext, data: object):
assert isinstance(data, tuple)
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
# Resample it to 1 sample per second
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
audio, sr = multi_modal_data.get("audio")
audio_duration = math.ceil(len(audio) / sr)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
cached_get_tokenizer(ctx.model_config.tokenizer),
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=62, # "_"
repeat_count=audio_duration)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", lambda *_, **__: 100)
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
def __init__(self, *args, multimodal_config: MultiModalConfig,
**kwargs):
assert multimodal_config is not None
super().__init__(*args, **kwargs)
def forward(
self,
*args,
processed_audio: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(*args, **kwargs)
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch(
"vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"), patch(
"vllm.model_executor.models.ModelRegistry.is_multimodal_model"
) as mock:
mock.return_value = True
sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
f"--port {port} --chat-template {chatml_jinja_path} "
"--disable-frontend-multiprocessing").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server',
run_name='__main__')
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client():
port = get_open_port()
ctx = torch.multiprocessing.get_context("spawn")
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
# run health check
health_url = f"http://localhost:{port}/health"
start = time.time()
while True:
try:
if requests.get(health_url).status_code == 200:
break
except Exception as err:
result = server.exitcode
if result is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server failed to start in time.") from err
try:
yield client
finally:
server.kill()
def client(server):
return server.get_async_client()
@pytest.fixture(scope="session")
@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
completion_tokens=10, prompt_tokens=202, total_tokens=212)
message = choice.message
message = chat_completion.choices[0].message
@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
completion_tokens=10, prompt_tokens=202, total_tokens=212)
message = choice.message
message = chat_completion.choices[0].message

View File

@ -12,6 +12,7 @@ if not is_cpu():
# (xFormers, etc.)
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.sequence import SampleLogprobs
@ -131,7 +132,7 @@ if not is_cpu():
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from transformers import AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -80,7 +80,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Type
import pytest
from transformers import BatchEncoding
from transformers import AutoModelForVision2Seq, BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@ -74,7 +74,7 @@ def run_test(
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,7 +1,8 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -124,7 +125,7 @@ def run_test(
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.sequence import SampleLogprobs
@ -105,7 +105,8 @@ def run_test(
for prompts, images in vllm_inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -129,7 +129,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -2,7 +2,7 @@ import os
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -102,7 +102,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,

View File

@ -26,7 +26,7 @@ def test_text_only_qwen_model(
# for qwen-vl is still unsupported in VLLM. In the near-future, the
# implementation and this test will be extended to consider
# visual inputs as well.
with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts,
max_tokens,

View File

@ -0,0 +1,151 @@
from typing import List, Optional, Tuple, Type
import librosa
import numpy as np
import pytest
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from vllm.assets.audio import AudioAsset
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import HfRunner, VllmRunner
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
@pytest.fixture(scope="session")
def audio_and_sample_rate():
return AudioAsset("mary_had_lamb").audio_and_sample_rate
@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
vllm_placeholder = "<|reserved_special_token_0|>"
hf_placeholder = "<|audio|>"
question = "What's in the audio?"
vllm_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{vllm_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
hf_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{hf_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = output_ids[:]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
prompts_and_audios: List[Tuple[str, str, AudioTuple]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[audio])
for vllm_prompt, _, audio in prompts_and_audios
]
def process(hf_inputs: BatchEncoding):
hf_inputs["audio_values"] = hf_inputs["audio_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
vllm_outputs_per_audio):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
prompts_and_audios,
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

26
vllm/assets/audio.py Normal file
View File

@ -0,0 +1,26 @@
from dataclasses import dataclass
from typing import Literal, Tuple
from urllib.parse import urljoin
import librosa
import numpy as np
from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL
ASSET_DIR = "multimodal_asset"
@dataclass(frozen=True)
class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
@property
def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
@property
def url(self) -> str:
return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")

View File

@ -117,8 +117,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
if modality == "image":
model_type = model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
@ -134,7 +134,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")

View File

@ -61,7 +61,7 @@ _GENERATION_MODELS = {
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}
_EMBEDDING_MODELS = {
@ -83,6 +83,7 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),

View File

@ -15,8 +15,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@ -97,11 +97,11 @@ def input_processor_for_blip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)

View File

@ -30,8 +30,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from vllm.utils import print_warning_once
@ -124,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,

View File

@ -16,8 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@ -103,11 +103,11 @@ def input_processor_for_clip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)

View File

@ -36,8 +36,8 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)

View File

@ -23,7 +23,7 @@ from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,

View File

@ -54,8 +54,8 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)

View File

@ -16,7 +16,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsMultiModal

View File

@ -37,7 +37,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,

View File

@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@ -112,11 +112,11 @@ def input_processor_for_siglip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)

View File

@ -0,0 +1,435 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import itertools
import math
from array import array
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast)
import librosa
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, 80, M)"""
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: torch.Tensor
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs]
@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
return WhisperFeatureExtractor.from_pretrained(model_id)
def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
return cached_feature_extractor(
ctx.get_hf_config(UltravoxConfig).audio_model_id)
def get_ultravox_max_audio_tokens(ctx: InputContext):
feature_extractor = whisper_feature_extractor(ctx)
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
def dummy_data_for_ultravox(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
):
feature_extractor = whisper_feature_extractor(ctx)
audio_count = mm_counts["audio"]
audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
_AUDIO_PLACEHOLDER_TOKEN
]) * get_ultravox_max_audio_tokens(ctx) * audio_count
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
mm_dict = {
"audio":
audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
}
return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if isinstance(data, tuple):
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
feature_extractor = whisper_feature_extractor(ctx)
if sr != feature_extractor.sampling_rate:
audio = librosa.resample(audio,
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
sr = feature_extractor.sampling_rate
minimum_audio_length = feature_extractor.n_fft // 2 + 1
if len(audio) < minimum_audio_length:
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
return MultiModalInputs({
"audio_features":
feature_extractor(audio,
sampling_rate=sr,
padding="longest",
return_tensors="pt")["input_features"]
})
raise NotImplementedError(f"Unsupported data type: {type(data)}")
def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
feature_extractor = whisper_feature_extractor(ctx)
audio_data, sample_rate = multi_modal_data["audio"]
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length -
(feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_num_tokens,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
class StackAudioFrames(nn.Module):
"""
Stack the audio embedding frames to reduce the sequence length by a factor
of `stack_factor`.
"""
def __init__(self, stack_factor: int = 8):
super().__init__()
self.stack_factor = stack_factor
def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
B, T, C = audio_embeds.shape
T_pad = (T + self.stack_factor -
1) // self.stack_factor * self.stack_factor
audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
B, T, C = audio_embeds.shape
audio_embeds = audio_embeds.view(B, T // self.stack_factor,
C * self.stack_factor)
return audio_embeds
class FlippedSiluAndMul(SiluAndMul):
"""Ultravox is trained with SwiGLU with flipped halves."""
def forward(self, x: torch.Tensor):
a, b = x.chunk(2, dim=-1)
flipped = torch.cat((b, a), dim=-1)
return super().forward(flipped)
class UltravoxProjector(nn.Module):
def __init__(self, config: UltravoxConfig):
super().__init__()
self.hidden_dim = config.hidden_size
self._pad_and_stack = StackAudioFrames(config.stack_factor)
dim = config.audio_config.hidden_size * config.stack_factor
self.ln_pre = RMSNorm(dim)
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
dim = self.hidden_dim
if config.projector_act == "swiglu":
self.act = FlippedSiluAndMul()
dim = dim // 2
else:
self.act = get_act_fn(config.projector_act)
self.linear_2 = nn.Linear(dim,
config.text_config.hidden_size,
bias=False)
self.ln_post = RMSNorm(config.text_config.hidden_size)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_post(hidden_states)
return hidden_states
class ModifiedWhisperEncoder(WhisperEncoder):
"""
Encoder portion of OpenAI's Whisper model.
This implementation is a slightly modified version of HF Transformers'
Whisper Encoder, with only a few fixes:
1. base_model_prefix updated to allow for doing `.from_pretrained`
directly on the encoder
2. allow less than 30 second of audio padding to be passed in:
- relaxed ValueError check for `input_features` length to be less
than or equal to `expected_seq_length` instead of strictly equal
- embed_pos is now sliced to match the length of `inputs_embeds`
Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
See commentary: https://github.com/huggingface/transformers/issues/25744
"""
base_model_prefix = "model.encoder"
def forward(
self,
input_features,
):
expected_seq_length = (self.config.max_source_positions *
self.conv1.stride[0] * self.conv2.stride[0])
if input_features.shape[-1] > expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length "
f"{expected_seq_length} or less, but found "
f"{input_features.shape[-1]}. Make sure to pad the input mel "
f"features to {expected_seq_length}.")
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states,
p=self.dropout,
training=self.training)
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=None,
)
hidden_states = layer_outputs[0]
hidden_states = self.layer_norm(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):
def __init__(self,
config: UltravoxConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None):
super().__init__()
self.config = config
self.multi_modal_config = multimodal_config
assert self.multi_modal_config
if config.audio_model_id is not None:
self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id)
else:
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
audio_features = self.audio_tower(audio_input)
audio_features = audio_features.to(self.audio_tower.dtype)
audio_embeddings = self.multi_modal_projector(audio_features)
return audio_embeddings
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None:
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
return UltravoxAudioFeatureInputs(type="audio_features",
data=audio_features)
if audio_embeds is not None:
if not isinstance(audio_embeds, torch.Tensor):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return UltravoxAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
raise AssertionError("This line should be unreachable.")
def _process_audio_input(
self, audio_input: UltravoxAudioInputs
) -> Union[torch.Tensor, List[torch.Tensor]]:
if audio_input["type"] == "audio_embeds":
return audio_input["data"]
audio_features = audio_input["data"]
if isinstance(audio_features, list):
# TODO: Batch these through the encoder/projector instead of
# serializing them.
return [
self._audio_features_to_embeddings(
features.unsqueeze(0)).squeeze(0)
for features in audio_features
]
else:
return self._audio_features_to_embeddings(audio_features)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor],
**kwargs) -> SamplerOutput:
"""Run forward pass for Ultravox
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted audio embeddings. The to-be-inserted
audio has a size that is essentially 6.25 tokens per second of audio.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Args:
input_features: A batch of audio inputs, [1, 80, M].
"""
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
projector_weights, llm_weights = itertools.tee(weights, 2)
# load projector weights
projector_weights = filter_weights(projector_weights,
"multi_modal_projector")
projector_params_dict = dict(
self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights:
param = projector_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)

View File

@ -1,5 +1,4 @@
from functools import lru_cache
from typing import List, Optional, Tuple, TypeVar
import torch
from PIL import Image
@ -8,7 +7,6 @@ from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@ -16,87 +14,6 @@ from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
# Utilities for image input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
def repeat_and_pad_image_tokens(
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
image_token_id: int,
repeat_count: int = 1,
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
if prompt is None:
new_prompt = None
else:
image_token_str = tokenizer.decode(image_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_str = "".join(
repeat_and_pad_token(
image_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
image_token_count = prompt.count(image_token_str)
# This is an arbitrary number to distinguish between the two cases
if image_token_count > 16:
logger.warning(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", image_token_str)
elif image_token_count > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(image_token_str, replacement_str, 1)
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
if token == image_token_id:
replacement_ids = repeat_and_pad_token(
image_token_id,
repeat_count=repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
return new_prompt, new_token_ids
class ImagePlugin(MultiModalPlugin):

View File

@ -1,6 +1,7 @@
import base64
from functools import lru_cache
from io import BytesIO
from typing import Tuple, Union
from typing import List, Optional, Tuple, TypeVar, Union
import librosa
import numpy as np
@ -9,7 +10,13 @@ from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
def _load_image_from_bytes(b: bytes):
@ -154,3 +161,84 @@ def rescale_image_size(image: Image.Image,
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image
# Utilities for input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
def repeat_and_pad_placeholder_tokens(
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
placeholder_token_id: int,
repeat_count: int = 1,
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
if prompt is None:
new_prompt = None
else:
placeholder_token_str = tokenizer.decode(placeholder_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
placeholder_token_count = prompt.count(placeholder_token_str)
# This is an arbitrary number to distinguish between the two cases
if placeholder_token_count > 16:
logger.warning(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", placeholder_token_str)
elif placeholder_token_count > 1:
logger.warning("Multiple multi-modal input is not supported yet, "
"so any extra placeholder tokens will be treated "
"as plain text.")
# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
replacement_ids = repeat_and_pad_token(
placeholder_token_id,
repeat_count=repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
return new_prompt, new_token_ids

View File

@ -12,7 +12,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig)
RWConfig, UltravoxConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@ -32,6 +32,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"medusa": MedusaConfig,
"internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig,
"ultravox": UltravoxConfig,
}
for name, cls in _CONFIG_REGISTRY.items():

View File

@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
"ChatGLMConfig",
@ -21,4 +22,5 @@ __all__ = [
"MedusaConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"UltravoxConfig",
]

View File

@ -0,0 +1,99 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
from typing import Any, Dict, Optional
import transformers
class UltravoxConfig(transformers.PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`UltravoxForConditionalGeneration`]. It is used to instantiate an
Ultravox model according to the specified arguments, defining the model
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to
control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
audio_config (`Union[AutoConfig, dict]`, *optional*):
Custom audio config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig`
or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
audio_token_index (`int`, *optional*, defaults to 32000):
The audio token index to encode the audio prompt.
stack_factor (`int`, *optional*, defaults to 8):
Audio downsampling factor for the multimodal projector.
norm_init (`float`, *optional*, defaults to 0.4):
The initialization value for the layer normalization.
projector_act (`str`, *optional*, defaults to `"swiglu"`):
The activation function used by the multimodal projector.
text_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
"""
model_type = "ultravox"
is_composition = False
def __init__(
self,
audio_config: Optional[Dict[str, Any]] = None,
text_config: Optional[Dict[str, Any]] = None,
audio_model_id: Optional[str] = None,
text_model_id: Optional[str] = None,
ignore_index: int = -100,
audio_token_index: int = 32000,
hidden_size: int = 4096,
stack_factor: int = 8,
norm_init: float = 0.4,
projector_act: str = "swiglu",
text_model_lora_config: Optional[Dict[str, Any]] = None,
audio_model_lora_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
self.ignore_index = ignore_index
self.audio_model_id = audio_model_id
self.text_model_id = text_model_id
self.audio_token_index = audio_token_index
self.hidden_size = hidden_size
self.stack_factor = stack_factor
self.norm_init = norm_init
self.projector_act = projector_act
if text_model_id is not None:
# Avoid circular import
from vllm.transformers_utils.config import get_config
self.text_config = get_config(text_model_id,
trust_remote_code=False)
else:
text_config = text_config or {}
self.text_config = transformers.CONFIG_MAPPING[text_config.get(
"model_type", "llama")](**text_config)
if audio_model_id is not None:
# Avoid circular import
from vllm.transformers_utils.config import get_config
self.audio_config = get_config(audio_model_id,
trust_remote_code=False)
else:
audio_config = audio_config or {}
self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
"model_type", "whisper")](**audio_config)
self.text_model_lora_config = text_model_lora_config or {}
self.audio_model_lora_config = audio_model_lora_config or {}
self.vocab_size = self.text_config.vocab_size
self.initializer_range = self.text_config.initializer_range
super().__init__(**kwargs)