[Misc] Add Phi4-MM example (#14343)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
d0feea31c7
commit
952a074980
@ -6,10 +6,14 @@ with the correct prompt format on audio language models.
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
# Phi-4-multimodal-instruct
|
||||
def run_phi4mm(questions: str, audio_count: int):
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process audio inputs.
|
||||
"""
|
||||
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
speech_lora_path = os.path.join(model_path, "speech-lora")
|
||||
placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
|
||||
|
||||
prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
|
||||
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
lora_request = LoRARequest("speech", 1, speech_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
|
||||
# Qwen2-Audio
|
||||
def run_qwen2_audio(question: str, audio_count: int):
|
||||
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):
|
||||
|
||||
model_example_map = {
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
"qwen2_audio": run_qwen2_audio,
|
||||
"ultravox": run_ultravox,
|
||||
"whisper": run_whisper,
|
||||
|
@ -6,13 +6,16 @@ the correct prompt format on vision language models for text generation.
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
|
||||
# Phi-4-multimodal-instruct
|
||||
def run_phi4mm(questions: list[str], modality: str):
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process image inputs.
|
||||
"""
|
||||
assert modality == "image"
|
||||
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
vision_lora_path = os.path.join(model_path, "vision-lora")
|
||||
prompts = [
|
||||
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
|
||||
for question in questions
|
||||
]
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
lora_request = LoRARequest("vision", 1, vision_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
stop_token_ids = None
|
||||
return llm, prompts, stop_token_ids
|
||||
|
||||
|
||||
# Pixtral HF-format
|
||||
def run_pixtral_hf(questions: list[str], modality: str):
|
||||
assert modality == "image"
|
||||
@ -644,6 +681,7 @@ model_example_map = {
|
||||
"paligemma": run_paligemma,
|
||||
"paligemma2": run_paligemma2,
|
||||
"phi3_v": run_phi3v,
|
||||
"phi4_mm": run_phi4mm,
|
||||
"pixtral_hf": run_pixtral_hf,
|
||||
"qwen_vl": run_qwen_vl,
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
|
@ -4,13 +4,16 @@ This example shows how to use vLLM for running offline inference with
|
||||
multi-image input on vision language models for text generation,
|
||||
using the chat template defined by the model.
|
||||
"""
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process multi images inputs.
|
||||
"""
|
||||
|
||||
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
vision_lora_path = os.path.join(model_path, "vision-lora")
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=10000,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
lora_extra_vocab_size=0,
|
||||
)
|
||||
lora_request = LoRARequest("vision", 1, vision_lora_path)
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
|
||||
placeholders = "".join(f"<|image_{i}|>"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
|
||||
stop_token_ids = None
|
||||
|
||||
return ModelRequestData(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
chat_template=None,
|
||||
)
|
||||
|
||||
|
||||
def load_qwen_vl_chat(question: str,
|
||||
image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen-VL-Chat"
|
||||
@ -459,6 +502,7 @@ model_example_map = {
|
||||
"mllama": load_mllama,
|
||||
"NVLM_D": load_nvlm_d,
|
||||
"phi3_v": load_phi3v,
|
||||
"phi4_mm": load_phi4mm,
|
||||
"pixtral_hf": load_pixtral_hf,
|
||||
"qwen_vl_chat": load_qwen_vl_chat,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
|
@ -25,6 +25,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
|
||||
@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
"""
|
||||
Implements the Phi-4-multimodal-instruct model in VLLM.
|
||||
"""
|
||||
# LoRA specific attributes
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"qkv_proj",
|
||||
@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
"gate_up_proj",
|
||||
],
|
||||
}
|
||||
supported_lora_modules = [
|
||||
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
|
||||
]
|
||||
# Phi4MMForCausalLM does not apply LoRA to the embedding layer.
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@ -1801,3 +1795,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="model.",
|
||||
connector=["audio_projection_for_vision", "audio_projection"],
|
||||
tower_model=["vision_encoder", "embed_tokens_extend"],
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user