[Misc] Add Phi4-MM example (#14343)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-03-08 01:28:52 +08:00 committed by GitHub
parent d0feea31c7
commit 952a074980
4 changed files with 131 additions and 7 deletions


@@ -6,10 +6,14 @@ with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
    return llm, prompt, stop_token_ids


# Phi-4-multimodal-instruct
def run_phi4mm(questions: str, audio_count: int):
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")
    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])

    prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"

    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        lora_extra_vocab_size=0,
    )

    lora_request = LoRARequest("speech", 1, speech_lora_path)
    # To maintain code compatibility in this script, we add LoRA here.
    llm.llm_engine.add_lora(lora_request=lora_request)
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)

    stop_token_ids = None
    return llm, prompts, stop_token_ids


# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int):
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"
@@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):
model_example_map = {
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "ultravox": run_ultravox,
    "whisper": run_whisper,


@@ -6,13 +6,16 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
import random
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
    return llm, prompts, stop_token_ids


# Phi-4-multimodal-instruct
def run_phi4mm(questions: list[str], modality: str):
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process image inputs.
    """
    assert modality == "image"
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")

    prompts = [
        f"<|user|><|image_1|>{question}<|end|><|assistant|>"
        for question in questions
    ]

    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        lora_extra_vocab_size=0,
    )

    lora_request = LoRARequest("vision", 1, vision_lora_path)
    # To maintain code compatibility in this script, we add LoRA here.
    llm.llm_engine.add_lora(lora_request=lora_request)
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)

    stop_token_ids = None
    return llm, prompts, stop_token_ids


# Pixtral HF-format
def run_pixtral_hf(questions: list[str], modality: str):
    assert modality == "image"
@@ -644,6 +681,7 @@ model_example_map = {
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "phi4_mm": run_phi4mm,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,


@@ -4,13 +4,16 @@ This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
import os
from argparse import Namespace
from typing import NamedTuple, Optional
from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser
@@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process multi-image inputs.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    vision_lora_path = os.path.join(model_path, "vision-lora")
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=10000,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        enable_lora=True,
        max_lora_rank=320,
        lora_extra_vocab_size=0,
    )

    lora_request = LoRARequest("vision", 1, vision_lora_path)
    # To maintain code compatibility in this script, we add LoRA here.
    llm.llm_engine.add_lora(lora_request=lora_request)
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)

    placeholders = "".join(f"<|image_{i}|>"
                           for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen_vl_chat(question: str,
                      image_urls: list[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
@@ -459,6 +502,7 @@ model_example_map = {
    "mllama": load_mllama,
    "NVLM_D": load_nvlm_d,
    "phi3_v": load_phi3v,
    "phi4_mm": load_phi4mm,
    "pixtral_hf": load_pixtral_hf,
    "qwen_vl_chat": load_qwen_vl_chat,
    "qwen2_vl": load_qwen2_vl,


@@ -25,6 +25,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalInputs, NestedTensors
@@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """
    Implements the Phi-4-multimodal-instruct model in VLLM.
    """
    # LoRA specific attributes
    packed_modules_mapping = {
        "qkv_proj": [
            "qkv_proj",
@@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
            "gate_up_proj",
        ],
    }
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
    ]
    # Phi4MMForCausalLM does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@@ -1801,3 +1795,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.",
            connector=["audio_projection_for_vision", "audio_projection"],
            tower_model=["vision_encoder", "embed_tokens_extend"],
        )
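To illustrate why `get_mm_mapping` matters here: the returned prefixes let callers distinguish the language model from the connector and tower (encoder) modules, e.g. so LoRA handling can be limited to the language-model weights. A minimal, hypothetical sketch (not vLLM internals), using the prefixes this model reports:

def classify_param(name: str,
                   language_model: list[str],
                   connector: list[str],
                   tower_model: list[str]) -> str:
    """Classify a parameter name by the component prefix it falls under."""
    for label, prefixes in (("connector", connector),
                            ("tower_model", tower_model),
                            ("language_model", language_model)):
        if any(name.startswith(p) for p in prefixes):
            return label
    return "other"


# With the prefixes returned by Phi4MMForCausalLM.get_mm_mapping():
print(classify_param("model.layers.0.self_attn.qkv_proj.weight",
                     ["model."],
                     ["audio_projection_for_vision", "audio_projection"],
                     ["vision_encoder", "embed_tokens_extend"]))
# -> language_model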