[Model] LoRA Support for Ultravox model (#11253)

parent 9cdea30b4f
commit d88506dda4
docs/source/models/supported_models.md
@@ -857,7 +857,7 @@ See [this page](#generative-models) for more information on how to use generative
 * Ultravox
 * T + A<sup>E+</sup>
 * `fixie-ai/ultravox-v0_3`
-*
+* ✅︎
 * ✅︎
 * ✅︎
 :::
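With the LoRA column now checked for Ultravox, the adapter flow matches other LoRA-enabled models. A minimal offline sketch (the adapter path below is hypothetical; this commit's test reuses a Llama chess adapter instead):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA at engine construction, mirroring the flags the new test uses.
llm = LLM(model="fixie-ai/ultravox-v0_3", enable_lora=True, max_lora_rank=128)

# Attach an adapter per request; LoRARequest takes (name, integer id, local path).
outputs = llm.generate(
    "Tell me about a Fool's mate move in 20 words. Provide the moves!",
    SamplingParams(temperature=0.0, max_tokens=256),
    lora_request=LoRARequest("chess-lora", 1, "/path/to/ultravox-lora"),
)
print(outputs[0].outputs[0].text)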
tests/conftest.py
@@ -737,6 +737,7 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         inputs = self.get_inputs(prompts,
                                  images=images,
@@ -744,7 +745,8 @@ class VllmRunner:
                                  audios=audios)

         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)

         outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
@@ -782,6 +784,7 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         inputs = self.get_inputs(prompts,
@@ -790,7 +793,8 @@ class VllmRunner:
                                  audios=audios)

         req_outputs = self.model.generate(inputs,
-                                          sampling_params=sampling_params)
+                                          sampling_params=sampling_params,
+                                          **kwargs)

         toks_str_logsprobs_prompt_logprobs = (
             self._final_steps_generate_w_logprobs(req_outputs))
@@ -826,13 +830,15 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
+        **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts,
                                 greedy_params,
                                 images=images,
                                 videos=videos,
-                                audios=audios)
+                                audios=audios,
+                                **kwargs)
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]

@@ -847,6 +853,7 @@ class VllmRunner:
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
         stop: Optional[List[str]] = None,
+        **kwargs: Any,
     ) -> Union[List[TokensTextLogprobs],
                List[TokensTextLogprobsPromptLogprobs]]:
         greedy_logprobs_params = SamplingParams(
@@ -861,7 +868,8 @@ class VllmRunner:
                                            greedy_logprobs_params,
                                            images=images,
                                            audios=audios,
-                                           videos=videos)
+                                           videos=videos,
+                                           **kwargs)

     def generate_encoder_decoder_greedy_logprobs(
         self,
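The `**kwargs` threading above exists so that per-request options reach `LLM.generate` unchanged; in this commit the payload is `lora_request`. A sketch of the resulting test-side call (argument values illustrative):

# Extra keyword args now flow generate_greedy -> generate -> self.model.generate.
outputs = vllm_model.generate_greedy(
    prompts,
    256,  # max_tokens
    lora_request=LoRARequest("ultravox-lora", 1, lora_path),
)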
tests/lora/test_ultravox.py  (new file, 121 lines)
@@ -0,0 +1,121 @@
+import shutil
+from os import path
+from tempfile import TemporaryDirectory
+from typing import List, Tuple
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file
+from transformers import AutoTokenizer
+
+from vllm.lora.request import LoRARequest
+
+from ..models.utils import check_outputs_equal
+
+ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
+LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
+
+VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
+
+PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"
+
+
+def llama3_1_8b_chess_lora_path():
+    return snapshot_download(
+        repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
+
+
+# A Llama LoRA adapter can't be used without module name transformation
+# because Ultravox nests the language model.
+def transform_module_names_for_ultravox(state_dict):
+    transformed_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = key.replace("base_model.model",
+                              "base_model.model.language_model")
+        transformed_state_dict[new_key] = value
+    return transformed_state_dict
+
+
+def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
+    tensor_file = "adapter_model.safetensors"
+    state_dict = load_file(path.join(source_repo, tensor_file))
+    transformed_state_dict = transform_module_names_for_ultravox(state_dict)
+
+    save_file(transformed_state_dict, path.join(target_path, tensor_file))
+
+    config_file = "adapter_config.json"
+    shutil.copyfile(path.join(source_repo, config_file),
+                    path.join(target_path, config_file))
+    return target_path
+
+
+def _get_prompt(audio_count, question, placeholder, model_name) -> str:
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    placeholder = f"{placeholder}\n" * audio_count
+
+    return tokenizer.apply_chat_template([{
+        'role': 'user',
+        'content': f"{placeholder}{question}"
+    }],
+                                         tokenize=False,
+                                         add_generation_prompt=True)
+
+
+def test_ultravox_lora(vllm_runner):
+    """
+    TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
+    """
+    # Workaround to prevent device mismatch in Whisper.
+    # Can be removed once it is fixed upstream in transformers:
+    # https://github.com/huggingface/transformers/pull/35866
+    torch.set_default_device("cpu")
+
+    llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
+    with TemporaryDirectory() as temp_ultravox_lora_dir:
+        llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
+            llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
+        with vllm_runner(
+                ULTRAVOX_MODEL_NAME,
+                enforce_eager=True,
+                max_num_seqs=2,
+                enable_lora=True,
+                max_loras=1,
+                max_lora_rank=128,
+                dtype="bfloat16",
+                max_model_len=1024,
+        ) as vllm_model:
+            ultravox_outputs: List[Tuple[
+                List[int], str]] = vllm_model.generate_greedy(
+                    [
+                        _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
+                                    ULTRAVOX_MODEL_NAME)
+                    ],
+                    256,
+                    lora_request=LoRARequest(str(1), 1,
+                                             llama3_1_8b_ultravox_chess_lora),
+                )
+
+    # Run Llama with the original chess LoRA to compare outputs with the above.
+    with vllm_runner(
+            LLMA_MODEL_NAME,
+            enforce_eager=True,
+            max_num_seqs=2,
+            enable_lora=True,
+            max_loras=1,
+            max_lora_rank=128,
+            dtype="bfloat16",
+            max_model_len=1024,
+    ) as vllm_model:
+        llama_outputs: List[Tuple[List[int], str]] = (
+            vllm_model.generate_greedy(
+                [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
+                256,
+                lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
+            ))
+
+    check_outputs_equal(
+        outputs_0_lst=ultravox_outputs,
+        outputs_1_lst=llama_outputs,
+        name_0="ultravox",
+        name_1="llama",
+    )
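The key rewrite in `transform_module_names_for_ultravox` is what makes a plain Llama adapter loadable here: Ultravox nests the decoder under `language_model`, so each PEFT key gains that prefix segment. An illustrative before/after (the exact key is hypothetical; real keys depend on the adapter):

key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
new_key = key.replace("base_model.model", "base_model.model.language_model")
# -> "base_model.model.language_model.model.layers.0.self_attn.q_proj.lora_A.weight"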
vllm/model_executor/models/ultravox.py
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -33,7 +34,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings,
@@ -343,7 +344,20 @@ class ModifiedWhisperEncoder(WhisperEncoder):
                                         UltravoxMultiModalProcessor,
                                         info=UltravoxProcessingInfo,
                                         dummy_inputs=UltravoxDummyInputsBuilder)
-class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
+class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA-specific attributes
+    # TODO: Add LoRA to the audio tower and projector.
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
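`packed_modules_mapping` is needed because vLLM fuses the decoder's projections: `q_proj`/`k_proj`/`v_proj` exist as a single `qkv_proj` layer and `gate_proj`/`up_proj` as `gate_up_proj`, so per-projection LoRA weights from an adapter have to be grouped onto the fused modules. A toy illustration of the grouping (hypothetical helper, not vLLM internals; shapes are Llama-3.1-8B's):

import torch

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def pack_lora_b(adapter: dict, packed_name: str) -> torch.Tensor:
    # Stack the per-projection lora_B matrices along the output dim,
    # matching the fused layer's concatenated output.
    parts = [adapter[name] for name in packed_modules_mapping[packed_name]]
    return torch.cat(parts, dim=0)

adapter = {"q_proj": torch.randn(4096, 16),
           "k_proj": torch.randn(1024, 16),
           "v_proj": torch.randn(1024, 16)}
print(pack_lora_b(adapter, "qkv_proj").shape)  # torch.Size([6144, 16])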
@@ -391,6 +405,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

         return get_sampler()

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models.
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model.",
+            connector="multi_modal_projector.",
+            tower_model="audio_tower.",
+        )
+
     def _audio_features_to_embeddings(
             self, input_features: torch.Tensor) -> torch.Tensor:
         audio_input = input_features.to(self.audio_tower.dtype)
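`get_mm_mapping` hands the LoRA machinery a map from sub-model to module prefix; per the TODO above, only the `language_model.` subtree is a LoRA target for now, while `multi_modal_projector.` and `audio_tower.` are left alone. A rough sketch of the kind of prefix filtering this enables (hypothetical helper, not vLLM's actual code):

def is_lora_target(module_name: str) -> bool:
    # Only language-model modules carry adapters for Ultravox today.
    return module_name.startswith("language_model.")

# is_lora_target("language_model.model.layers.0.self_attn.qkv_proj")  -> True
# is_lora_target("audio_tower.layers.0.self_attn.q_proj")             -> False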