[Model] Add UltravoxModel and UltravoxConfig (#7615)
parent dd53c4b023 · commit 1ca0d4f86b
@@ -186,7 +186,7 @@ Multimodal Language Models
   * - Architecture
     - Models
-    - Supported Modality(ies)
+    - Supported Modalities
     - Example HuggingFace Models
     - :ref:`LoRA <lora>`
   * - :code:`Blip2ForConditionalGeneration`
@@ -234,6 +234,11 @@ Multimodal Language Models
     - Image
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
+  * - :code:`UltravoxModel`
+    - Ultravox
+    - Audio
+    - :code:`fixie-ai/ultravox-v0_3`
+    -

 .. note::
   For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
examples/offline_inference_audio_language.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on vision language models.
|
||||
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Input audio and question
|
||||
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
question = "What is recited in the audio?"
|
||||
|
||||
|
||||
# Ultravox 0.3
|
||||
def run_ultravox(question):
|
||||
model_name = "fixie-ai/ultravox-v0_3"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{
|
||||
'role': 'user',
|
||||
'content': f"<|reserved_special_token_0|>\n{question}"
|
||||
}]
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
llm = LLM(model=model_name)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"ultravox": run_ultravox,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
llm, prompt, stop_token_ids = model_example_map[model](question)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2,
|
||||
max_tokens=64,
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
# Single inference
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
},
|
||||
}
|
||||
|
||||
else:
|
||||
# Batch inference
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
},
|
||||
} for _ in range(args.num_prompts)]
|
||||
|
||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'audio language models')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="ultravox",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of prompts to run.')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
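For reference, with the argparse defaults above (`--model-type ultravox`, one prompt), this example can be run as `python examples/offline_inference_audio_language.py`, assuming vLLM is installed and the `fixie-ai/ultravox-v0_3` weights can be downloaded.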
examples/openai_audio_api_client.py (new file, 90 lines)
@@ -0,0 +1,90 @@
"""An example showing how to use vLLM to serve VLMs.
|
||||
|
||||
Launch the vLLM server with the following command:
|
||||
vllm serve fixie-ai/ultravox-v0_3
|
||||
"""
|
||||
import base64
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
# Any format supported by librosa is supported
|
||||
audio_url = AudioAsset("winning_call").url
|
||||
|
||||
# Use audio url in the payload
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this audio?"
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": audio_url
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
||||
|
||||
|
||||
# Use base64 encoded audio in the payload
|
||||
def encode_audio_base64_from_url(audio_url: str) -> str:
|
||||
"""Encode an audio retrieved from a remote url to base64 format."""
|
||||
|
||||
with requests.get(audio_url) as response:
|
||||
response.raise_for_status()
|
||||
result = base64.b64encode(response.content).decode('utf-8')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
audio_base64 = encode_audio_base64_from_url(audio_url=audio_url)
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this audio?"
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
# Any format supported by librosa is supported
|
||||
"url": f"data:audio/ogg;base64,{audio_base64}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
@@ -9,14 +9,14 @@ from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                     TypeVar, Union)

+import numpy as np
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from huggingface_hub import snapshot_download
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
-                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                           BatchFeature)

 from vllm import LLM, SamplingParams
@@ -216,8 +216,7 @@ class HfRunner:
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
-        is_vision_model: bool = False,
-        is_encoder_decoder_model: bool = False,
+        auto_cls=AutoModelForCausalLM,
         postprocess_inputs: Callable[[BatchEncoding],
                                      BatchEncoding] = identity,
     ) -> None:
@@ -234,13 +233,6 @@ class HfRunner:
                 device="cpu",
             ).to(dtype=torch_dtype))
         else:
-            if is_vision_model:
-                auto_cls = AutoModelForVision2Seq
-            elif is_encoder_decoder_model:
-                auto_cls = AutoModelForSeq2SeqLM
-            else:
-                auto_cls = AutoModelForCausalLM
-
             model_kwargs = model_kwargs if model_kwargs is not None else {}
             self.model = self.wrap_device(
                 auto_cls.from_pretrained(
@@ -432,6 +424,7 @@ class HfRunner:
         max_tokens: int,
         num_logprobs: int,
         images: Optional[List[Image.Image]] = None,
+        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
         all_logprobs: List[List[Dict[int, float]]] = []
@@ -446,6 +439,11 @@
             if images is not None and images[i] is not None:
                 processor_kwargs["images"] = images[i]

+            if audios is not None:
+                audio, sr = audios[i]
+                processor_kwargs["audio"] = audio
+                processor_kwargs["sampling_rate"] = sr
+
             inputs = self.processor(**processor_kwargs)
             inputs = self.postprocess_inputs(inputs)

@@ -627,6 +625,8 @@ class VllmRunner:
         sampling_params: SamplingParams,
         images: Optional[Union[List[Image.Image],
                                List[List[Image.Image]]]] = None,
+        audios: Optional[Union[List[Tuple[np.ndarray, int]],
+                               List[List[Tuple[np.ndarray, int]]]]] = None
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None

@@ -638,6 +638,10 @@ class VllmRunner:
             for i, image in enumerate(images):
                 inputs[i]["multi_modal_data"] = {"image": image}

+        if audios is not None:
+            for i, audio in enumerate(audios):
+                inputs[i]["multi_modal_data"] = {"audio": audio}
+
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
         return self._final_steps_generate_w_logprobs(req_outputs)
@@ -674,6 +678,8 @@ class VllmRunner:
         num_logprobs: int,
         images: Optional[Union[List[Image.Image],
                                List[List[Image.Image]]]] = None,
+        audios: Optional[Union[List[Tuple[np.ndarray, int]],
+                               List[List[Tuple[np.ndarray, int]]]]] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
@@ -682,7 +688,8 @@ class VllmRunner:
                                                 stop_token_ids=stop_token_ids)
         outputs = self.generate_w_logprobs(prompts,
                                            greedy_logprobs_params,
-                                           images=images)
+                                           images=images,
+                                           audios=audios)

         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
@@ -10,6 +10,7 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
 """

 import pytest
+from transformers import AutoModelForSeq2SeqLM

 from vllm.utils import cuda_device_count_stateless

@@ -85,7 +86,7 @@ def test_models(
     }

     with hf_runner(model, dtype=dtype,
-                   is_encoder_decoder_model=True) as hf_model:
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
         hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
             test_prompts,
             max_tokens,
@ -1,138 +1,36 @@
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Union, cast
|
||||
from unittest.mock import patch
|
||||
from typing import Dict, List
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from vllm import ModelRegistry
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.inputs.data import LLMInputs
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
from ...utils import VLLM_PATH
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert chatml_jinja_path.exists()
|
||||
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_3"
|
||||
TEST_AUDIO_URLS = [
|
||||
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
|
||||
AudioAsset("winning_call").url,
|
||||
]
|
||||
|
||||
|
||||
def server_function(port):
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--enforce-eager",
|
||||
]
|
||||
|
||||
def fake_input_mapper(ctx: InputContext, data: object):
|
||||
assert isinstance(data, tuple)
|
||||
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
|
||||
|
||||
# Resample it to 1 sample per second
|
||||
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
|
||||
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
|
||||
|
||||
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "audio" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
audio, sr = multi_modal_data.get("audio")
|
||||
audio_duration = math.ceil(len(audio) / sr)
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=62, # "_"
|
||||
repeat_count=audio_duration)
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"audio", lambda *_, **__: 100)
|
||||
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
|
||||
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
|
||||
|
||||
def __init__(self, *args, multimodal_config: MultiModalConfig,
|
||||
**kwargs):
|
||||
assert multimodal_config is not None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*args,
|
||||
processed_audio: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
|
||||
|
||||
with patch(
|
||||
"vllm.entrypoints.chat_utils._mm_token_str",
|
||||
lambda *_, **__: "_"), patch(
|
||||
"vllm.model_executor.models.ModelRegistry.is_multimodal_model"
|
||||
) as mock:
|
||||
mock.return_value = True
|
||||
sys.argv = ["placeholder.py"] + \
|
||||
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
|
||||
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
|
||||
f"--port {port} --chat-template {chatml_jinja_path} "
|
||||
"--disable-frontend-multiprocessing").split()
|
||||
import runpy
|
||||
runpy.run_module('vllm.entrypoints.openai.api_server',
|
||||
run_name='__main__')
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
port = get_open_port()
|
||||
ctx = torch.multiprocessing.get_context("spawn")
|
||||
server = ctx.Process(target=server_function, args=(port, ))
|
||||
server.start()
|
||||
MAX_SERVER_START_WAIT_S = 60
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url=f"http://localhost:{port}/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
# run health check
|
||||
health_url = f"http://localhost:{port}/health"
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
if requests.get(health_url).status_code == 200:
|
||||
break
|
||||
except Exception as err:
|
||||
result = server.exitcode
|
||||
if result is not None:
|
||||
raise RuntimeError("Server exited unexpectedly.") from err
|
||||
|
||||
time.sleep(0.5)
|
||||
if time.time() - start > MAX_SERVER_START_WAIT_S:
|
||||
raise RuntimeError("Server failed to start in time.") from err
|
||||
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
server.kill()
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=36, total_tokens=46)
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded(
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=36, total_tokens=46)
|
||||
completion_tokens=10, prompt_tokens=202, total_tokens=212)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
|
@ -12,6 +12,7 @@ if not is_cpu():
|
||||
# (xFormers, etc.)
|
||||
|
||||
import pytest
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -131,7 +132,7 @@ if not is_cpu():
|
||||
}
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
is_encoder_decoder_model=True) as hf_model:
|
||||
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
hf_outputs = (
|
||||
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
test_case_prompts,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModelForVision2Seq, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -80,7 +80,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
from transformers import BatchEncoding
|
||||
from transformers import AutoModelForVision2Seq, BatchEncoding
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
@ -74,7 +74,7 @@ def run_test(
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
postprocess_inputs=process,
|
||||
is_vision_model=True) as hf_model:
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -1,7 +1,8 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
|
||||
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -124,7 +125,7 @@ def run_test(
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
postprocess_inputs=process,
|
||||
is_vision_model=True) as hf_model:
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -105,7 +105,8 @@ def run_test(
|
||||
for prompts, images in vllm_inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional, Tuple, Type, overload
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -129,7 +129,8 @@ def run_test(
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -102,7 +102,8 @@ def run_test(
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
|
||||
with hf_runner(model, dtype=dtype,
|
||||
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
|
@ -26,7 +26,7 @@ def test_text_only_qwen_model(
|
||||
# for qwen-vl is still unsupported in VLLM. In the near-future, the
|
||||
# implementation and this test will be extended to consider
|
||||
# visual inputs as well.
|
||||
with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts,
|
||||
max_tokens,
|
||||
|
tests/models/test_ultravox.py (new file, 151 lines)
@@ -0,0 +1,151 @@
from typing import List, Optional, Tuple, Type

import librosa
import numpy as np
import pytest
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from vllm.assets.audio import AudioAsset
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import HfRunner, VllmRunner
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

MODEL_NAME = "fixie-ai/ultravox-v0_3"

AudioTuple = Tuple[np.ndarray, int]


@pytest.fixture(scope="session")
def audio_and_sample_rate():
    return AudioAsset("mary_had_lamb").audio_and_sample_rate


@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    vllm_placeholder = "<|reserved_special_token_0|>"
    hf_placeholder = "<|audio|>"

    question = "What's in the audio?"
    vllm_prompt = tokenizer.apply_chat_template(
        [{
            'role': 'user',
            'content': f"{vllm_placeholder}\n{question}"
        }],
        tokenize=False,
        add_generation_prompt=True)
    hf_prompt = tokenizer.apply_chat_template(
        [{
            'role': 'user',
            'content': f"{hf_placeholder}\n{question}"
        }],
        tokenize=False,
        add_generation_prompt=True)

    return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = output_ids[:]
    hf_output_str = output_str
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts_and_audios: List[Tuple[str, str, AudioTuple]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm."""
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_audio = [
            vllm_model.generate_greedy_logprobs([vllm_prompt],
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                audios=[audio])
            for vllm_prompt, _, audio in prompts_and_audios
        ]

    def process(hf_inputs: BatchEncoding):
        hf_inputs["audio_values"] = hf_inputs["audio_values"] \
            .to(torch_dtype)  # type: ignore
        return hf_inputs

    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModel) as hf_model:

        hf_outputs_per_audio = [
            hf_model.generate_greedy_logprobs_limit(
                [hf_prompt],
                max_tokens,
                num_logprobs=num_logprobs,
                audios=[(librosa.resample(audio[0],
                                          orig_sr=audio[1],
                                          target_sr=16000), 16000)])
            for _, hf_prompt, audio in prompts_and_audios
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
                                        vllm_outputs_per_audio):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
                max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        prompts_and_audios,
        MODEL_NAME,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
vllm/assets/audio.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from dataclasses import dataclass
from typing import Literal, Tuple
from urllib.parse import urljoin

import librosa
import numpy as np

from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL

ASSET_DIR = "multimodal_asset"


@dataclass(frozen=True)
class AudioAsset:
    name: Literal["winning_call", "mary_had_lamb"]

    @property
    def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:

        audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                            s3_prefix=ASSET_DIR)
        return librosa.load(audio_path, sr=None)

    @property
    def url(self) -> str:
        return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
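For reference, the examples added in this PR consume this asset class roughly as follows (a minimal usage sketch based on the files above):

    from vllm.assets.audio import AudioAsset

    # Decoded waveform and sampling rate, used as offline inference input.
    audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate

    # Public URL of the packaged asset, used by the OpenAI API client example.
    url = AudioAsset("winning_call").url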
@@ -117,8 +117,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
                   modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
-    if modality == "image":
-        model_type = model_config.hf_config.model_type
+    model_type = model_config.hf_config.model_type
+    if modality == "image":
         if model_type == "phi3_v":
             # Workaround since this token is not defined in the tokenizer
             return "<|image_1|>"
@@ -134,7 +134,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,

         raise TypeError(f"Unknown model type: {model_type}")
     elif modality == "audio":
-        raise TypeError("No audio models are supported yet.")
+        if model_type == "ultravox":
+            return "<|reserved_special_token_0|>"
+        raise TypeError(f"Unknown model type: {model_type}")
     else:
         raise TypeError(f"Unknown modality: {modality}")

@@ -61,7 +61,7 @@ _GENERATION_MODELS = {
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
-    "JambaForCausalLM": ("jamba", "JambaForCausalLM")
+    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
 }

 _EMBEDDING_MODELS = {
@@ -83,6 +83,7 @@ _MULTIMODAL_MODELS = {
     "PaliGemmaForConditionalGeneration": ("paligemma",
                                           "PaliGemmaForConditionalGeneration"),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+    "UltravoxModel": ("ultravox", "UltravoxModel"),
 }
 _CONDITIONAL_GENERATION_MODELS = {
     "BartModel": ("bart", "BartForConditionalGeneration"),
@ -15,8 +15,8 @@ from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
@ -97,11 +97,11 @@ def input_processor_for_blip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=image_token_id,
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
|
@ -30,8 +30,8 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
from vllm.utils import print_warning_once
|
||||
@ -124,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
|
||||
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
|
||||
|
@ -16,8 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
@ -103,11 +103,11 @@ def input_processor_for_clip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=image_token_id,
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
|
@ -36,8 +36,8 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import (cached_get_image_processor,
|
||||
cached_get_tokenizer)
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
|
||||
|
@ -23,7 +23,7 @@ from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
|
@ -54,8 +54,8 @@ from vllm.model_executor.models.minicpm import MiniCPMModel
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import (cached_get_image_processor,
|
||||
cached_get_tokenizer)
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SamplerOutput, SequenceData)
|
||||
|
||||
|
@ -16,7 +16,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.gemma import GemmaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsMultiModal
|
||||
|
@ -37,7 +37,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
|
@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||
repeat_and_pad_image_tokens)
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
|
||||
@ -112,11 +112,11 @@ def input_processor_for_siglip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
image_token_id=image_token_id,
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
|
vllm/model_executor/models/ultravox.py (new file, 435 lines)
@@ -0,0 +1,435 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""

import itertools
import math
from array import array
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union, cast)

import librosa
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights,
                                              init_vllm_registered_model,
                                              merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25

logger = init_logger(__name__)


class UltravoxAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, 80, M)`"""


class UltravoxAudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    data: torch.Tensor


UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
    return WhisperFeatureExtractor.from_pretrained(model_id)


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
    return cached_feature_extractor(
        ctx.get_hf_config(UltravoxConfig).audio_model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
    feature_extractor = whisper_feature_extractor(ctx)
    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


def dummy_data_for_ultravox(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
):
    feature_extractor = whisper_feature_extractor(ctx)

    audio_count = mm_counts["audio"]

    audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
        _AUDIO_PLACEHOLDER_TOKEN
    ]) * get_ultravox_max_audio_tokens(ctx) * audio_count
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                            [0]) * (seq_len - len(audio_token_ids))

    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
    mm_dict = {
        "audio":
        audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
    }

    return (SequenceData(audio_token_ids + other_token_ids), mm_dict)


def input_mapper_for_ultravox(ctx: InputContext, data: object):
    if isinstance(data, tuple):
        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
        feature_extractor = whisper_feature_extractor(ctx)

        if sr != feature_extractor.sampling_rate:
            audio = librosa.resample(audio,
                                     orig_sr=sr,
                                     target_sr=feature_extractor.sampling_rate)
            sr = feature_extractor.sampling_rate

        minimum_audio_length = feature_extractor.n_fft // 2 + 1
        if len(audio) < minimum_audio_length:
            # Not enough audio; pad it.
            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))

        return MultiModalInputs({
            "audio_features":
            feature_extractor(audio,
                              sampling_rate=sr,
                              padding="longest",
                              return_tensors="pt")["input_features"]
        })

    raise NotImplementedError(f"Unsupported data type: {type(data)}")


def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return llm_inputs

    feature_extractor = whisper_feature_extractor(ctx)
    audio_data, sample_rate = multi_modal_data["audio"]

    audio_length = audio_data.shape[0]
    if sample_rate != feature_extractor.sampling_rate:
        # Account for resampling.
        adjustment = feature_extractor.sampling_rate / sample_rate
        audio_length = math.ceil(adjustment * audio_length)

    feature_extractor_output_length = math.ceil(
        (audio_length -
         (feature_extractor.hop_length - 1)) / feature_extractor.hop_length)

    uv_config = ctx.get_hf_config(UltravoxConfig)
    audio_num_tokens = min(
        max(
            1,
            math.ceil(feature_extractor_output_length /
                      (uv_config.stack_factor * 2))),
        get_ultravox_max_audio_tokens(ctx))
    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
        repeat_count=audio_num_tokens,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor -
                 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        B, T, C = audio_embeds.shape
        audio_embeds = audio_embeds.view(B, T // self.stack_factor,
                                         C * self.stack_factor)
        return audio_embeds


class FlippedSiluAndMul(SiluAndMul):
    """Ultravox is trained with SwiGLU with flipped halves."""

    def forward(self, x: torch.Tensor):
        a, b = x.chunk(2, dim=-1)
        flipped = torch.cat((b, a), dim=-1)
        return super().forward(flipped)


class UltravoxProjector(nn.Module):

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim)
        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
        dim = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = FlippedSiluAndMul()
            dim = dim // 2
        else:
            self.act = get_act_fn(config.projector_act)

        self.linear_2 = nn.Linear(dim,
                                  config.text_config.hidden_size,
                                  bias=False)
        self.ln_post = RMSNorm(config.text_config.hidden_size)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
          than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        expected_seq_length = (self.config.max_source_positions *
                               self.conv1.stride[0] * self.conv2.stride[0])
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: UltravoxConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional["QuantizationConfig"] = None):
        super().__init__()
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        if config.audio_model_id is not None:
            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
                config.audio_model_id)
        else:
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        audio_input = input_features.to(self.audio_tower.dtype)
        audio_features = self.audio_tower(audio_input)
        audio_features = audio_features.to(self.audio_tower.dtype)
        audio_embeddings = self.multi_modal_projector(audio_features)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, torch.Tensor):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: UltravoxAudioInputs
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        if isinstance(audio_features, list):
            # TODO: Batch these through the encoder/projector instead of
            # serializing them.
            return [
                self._audio_features_to_embeddings(
                    features.unsqueeze(0)).squeeze(0)
                for features in audio_features
            ]
        else:
            return self._audio_features_to_embeddings(audio_features)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[torch.Tensor],
                **kwargs) -> SamplerOutput:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_features: A batch of audio inputs, [1, 80, M].
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is not None:
            audio_embeddings = self._process_audio_input(audio_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, audio_embeddings,
                _AUDIO_PLACEHOLDER_TOKEN)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        projector_weights, llm_weights = itertools.tee(weights, 2)

        # load projector weights
        projector_weights = filter_weights(projector_weights,
                                           "multi_modal_projector")
        projector_params_dict = dict(
            self.multi_modal_projector.named_parameters())
        for name, loaded_weight in projector_weights:
            param = projector_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
@ -1,5 +1,4 @@
from functools import lru_cache
from typing import List, Optional, Tuple, TypeVar

import torch
from PIL import Image
@ -8,7 +7,6 @@ from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of

from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
@ -16,87 +14,6 @@ from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
logger = init_logger(__name__)

cached_get_image_processor = lru_cache(get_image_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)

# Utilities for image input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement


def repeat_and_pad_image_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    image_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    if prompt is None:
        new_prompt = None
    else:
        image_token_str = tokenizer.decode(image_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))
        replacement_str = "".join(
            repeat_and_pad_token(
                image_token_str,
                repeat_count=repeat_count,
                pad_token_left=pad_token_str_left,
                pad_token_right=pad_token_str_right,
            ))

        image_token_count = prompt.count(image_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if image_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", image_token_str)
        elif image_token_count > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        # The image tokens are removed to be consistent with HuggingFace
        new_prompt = prompt.replace(image_token_str, replacement_str, 1)

    new_token_ids: List[int] = []
    for i, token in enumerate(prompt_token_ids):
        if token == image_token_id:
            replacement_ids = repeat_and_pad_token(
                image_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids


class ImagePlugin(MultiModalPlugin):
@ -1,6 +1,7 @@
import base64
from functools import lru_cache
from io import BytesIO
from typing import Tuple, Union
from typing import List, Optional, Tuple, TypeVar, Union

import librosa
import numpy as np
@ -9,7 +10,13 @@ from PIL import Image

from vllm.connections import global_http_connection
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.logger import init_logger
from vllm.multimodal.base import MultiModalDataDict
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

logger = init_logger(__name__)

cached_get_tokenizer = lru_cache(get_tokenizer)


def _load_image_from_bytes(b: bytes):
@ -154,3 +161,84 @@ def rescale_image_size(image: Image.Image,
    if transpose >= 0:
        image = image.transpose(Image.Transpose(transpose))
    return image


# Utilities for input processors
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]

    return replacement


def repeat_and_pad_placeholder_tokens(
    tokenizer: AnyTokenizer,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    placeholder_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    if prompt is None:
        new_prompt = None
    else:
        placeholder_token_str = tokenizer.decode(placeholder_token_id)
        pad_token_str_left = (None if pad_token_left is None else
                              tokenizer.decode(pad_token_left))
        pad_token_str_right = (None if pad_token_right is None else
                               tokenizer.decode(pad_token_right))
        replacement_str = "".join(
            repeat_and_pad_token(
                placeholder_token_str,
                repeat_count=repeat_count,
                pad_token_left=pad_token_str_left,
                pad_token_right=pad_token_str_right,
            ))

        placeholder_token_count = prompt.count(placeholder_token_str)
        # This is an arbitrary number to distinguish between the two cases
        if placeholder_token_count > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", placeholder_token_str)
        elif placeholder_token_count > 1:
            logger.warning("Multiple multi-modal inputs are not supported yet, "
                           "so any extra placeholder tokens will be treated "
                           "as plain text.")

        # The placeholder tokens are removed to be consistent with HuggingFace
        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)

    new_token_ids: List[int] = []
    for i, token in enumerate(prompt_token_ids):
        if token == placeholder_token_id:
            replacement_ids = repeat_and_pad_token(
                placeholder_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            )
            new_token_ids.extend(replacement_ids)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + 1:])
            break
        else:
            new_token_ids.append(token)

    return new_prompt, new_token_ids
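For reference, a quick sketch of how the new helper expands a single placeholder token into the repeated form the model processor expects. The token IDs below are made up for illustration, and the import path assumes these helpers live in vllm.multimodal.utils:

# Illustrative usage of repeat_and_pad_token; the IDs are arbitrary examples,
# not real vocabulary entries, and the import path is assumed.
from vllm.multimodal.utils import repeat_and_pad_token

AUDIO_PLACEHOLDER_ID = 128002  # hypothetical placeholder token id

expanded = repeat_and_pad_token(AUDIO_PLACEHOLDER_ID,
                                repeat_count=3,
                                pad_token_left=10,
                                pad_token_right=11)
print(expanded)  # [10, 128002, 128002, 128002, 11]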
@ -12,7 +12,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             InternVLChatConfig, JAISConfig,
                                             MedusaConfig, MLPSpeculatorConfig,
                                             MPTConfig, NemotronConfig,
                                             RWConfig)
                                             RWConfig, UltravoxConfig)

if VLLM_USE_MODELSCOPE:
    from modelscope import AutoConfig
@ -32,6 +32,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "medusa": MedusaConfig,
    "internvl_chat": InternVLChatConfig,
    "nemotron": NemotronConfig,
    "ultravox": UltravoxConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
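With the registry entry in place, any checkpoint whose model_type is "ultravox" resolves to the config class added below. An illustrative lookup, not the exact resolution logic used by vLLM:

# Illustrative only: shows the model_type -> config-class mapping that the
# registry provides; the real resolution happens inside
# vllm.transformers_utils.config.get_config().
from vllm.transformers_utils.configs import UltravoxConfig

_CONFIG_REGISTRY = {"ultravox": UltravoxConfig}

config_cls = _CONFIG_REGISTRY["ultravox"]
config = config_cls()  # defaults: Llama text backbone, Whisper audio encoder
print(config.model_type, config.audio_config.model_type)  # ultravox whisper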
@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

__all__ = [
    "ChatGLMConfig",
@ -21,4 +22,5 @@ __all__ = [
    "MedusaConfig",
    "MLPSpeculatorConfig",
    "NemotronConfig",
    "UltravoxConfig",
]
99
vllm/transformers_utils/configs/ultravox.py
Normal file
@ -0,0 +1,99 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
from typing import Any, Dict, Optional

import transformers


class UltravoxConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`UltravoxForConditionalGeneration`]. It is used to instantiate an
    Ultravox model according to the specified arguments, defining the model
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        audio_config (`Union[AutoConfig, dict]`, *optional*):
            Custom audio config or dict.
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig`
            or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        audio_token_index (`int`, *optional*, defaults to 32000):
            The audio token index to encode the audio prompt.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        text_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the text model.
        audio_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the audio model.
    """

    model_type = "ultravox"
    is_composition = False

    def __init__(
        self,
        audio_config: Optional[Dict[str, Any]] = None,
        text_config: Optional[Dict[str, Any]] = None,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        ignore_index: int = -100,
        audio_token_index: int = 32000,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        text_model_lora_config: Optional[Dict[str, Any]] = None,
        audio_model_lora_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.audio_token_index = audio_token_index

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act

        if text_model_id is not None:
            # Avoid circular import
            from vllm.transformers_utils.config import get_config

            self.text_config = get_config(text_model_id,
                                          trust_remote_code=False)
        else:
            text_config = text_config or {}
            self.text_config = transformers.CONFIG_MAPPING[text_config.get(
                "model_type", "llama")](**text_config)

        if audio_model_id is not None:
            # Avoid circular import
            from vllm.transformers_utils.config import get_config

            self.audio_config = get_config(audio_model_id,
                                           trust_remote_code=False)
        else:
            audio_config = audio_config or {}
            self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
                "model_type", "whisper")](**audio_config)

        self.text_model_lora_config = text_model_lora_config or {}
        self.audio_model_lora_config = audio_model_lora_config or {}

        self.vocab_size = self.text_config.vocab_size

        self.initializer_range = self.text_config.initializer_range

        super().__init__(**kwargs)
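A minimal usage sketch of the new config class; the sub-config dicts and the audio token index below are assumed values chosen for illustration, not defaults taken from a released checkpoint:

# Minimal sketch: build an UltravoxConfig from explicit sub-config dicts.
# Only model_type is needed in each dict; everything else falls back to
# the transformers defaults for Whisper / Llama.
from vllm.transformers_utils.configs import UltravoxConfig

config = UltravoxConfig(
    audio_config={"model_type": "whisper"},
    text_config={"model_type": "llama", "vocab_size": 128256},
    audio_token_index=128002,  # assumed value for illustration
    stack_factor=8,
)
print(config.vocab_size)     # 128256, mirrored from the text backbone
print(config.projector_act)  # swiglu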