[Model][VLM] Add LLaVA-Onevision model support (#8486)

Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
litianjian 2024-09-23 01:51:44 +08:00 committed by GitHub
parent ca2b628b3c
commit 5b59532760
10 changed files with 1330 additions and 21 deletions

View File

@@ -244,6 +244,11 @@ Multimodal Language Models
     - Video
     - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note)
     -
+  * - :code:`LlavaOnevisionForConditionalGeneration`
+    - LLaVA-Onevision
+    - Image\ :sup:`+` / Video
+    - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. (see note)
+    -
   * - :code:`MiniCPMV`
     - MiniCPM-V
     - Image\ :sup:`+`
@@ -288,7 +293,7 @@ Multimodal Language Models
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 .. note::
-  For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
+  For :code:`LLaVA-NeXT-Video`, :code:`LLaVA-Onevision` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
   This can be installed by running the following command:
   .. code-block:: bash

View File

@@ -14,7 +14,8 @@ from vllm.utils import FlexibleArgumentParser
 # LLaVA-1.5
-def run_llava(question):
+def run_llava(question, modality):
+    assert modality == "image"
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
@@ -24,7 +25,8 @@ def run_llava(question):
 # LLaVA-1.6/LLaVA-NeXT
-def run_llava_next(question):
+def run_llava_next(question, modality):
+    assert modality == "image"
     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
@@ -34,15 +36,35 @@ def run_llava_next(question):
 # LlaVA-NeXT-Video
 # Currently only support for video input
-def run_llava_next_video(question):
+def run_llava_next_video(question, modality):
+    assert modality == "video"
     prompt = f"USER: <video>\n{question} ASSISTANT:"
     llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
+# LLaVA-OneVision
+def run_llava_onevision(question, modality):
+    if modality == "video":
+        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
+        <|im_start|>assistant\n"
+    elif modality == "image":
+        prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
+        <|im_start|>assistant\n"
+    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
+              max_model_len=32768)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 # Fuyu
-def run_fuyu(question):
+def run_fuyu(question, modality):
+    assert modality == "image"
     prompt = f"{question}\n"
     llm = LLM(model="adept/fuyu-8b")
@@ -51,7 +73,8 @@ def run_fuyu(question):
 # Phi-3-Vision
-def run_phi3v(question):
+def run_phi3v(question, modality):
+    assert modality == "image"
     prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
     # Note: The default setting of max_num_seqs (256) and
@@ -70,7 +93,8 @@ def run_phi3v(question):
 # PaliGemma
-def run_paligemma(question):
+def run_paligemma(question, modality):
+    assert modality == "image"
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
@@ -80,7 +104,8 @@ def run_paligemma(question):
 # Chameleon
-def run_chameleon(question):
+def run_chameleon(question, modality):
+    assert modality == "image"
     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b")
@@ -89,7 +114,8 @@ def run_chameleon(question):
 # MiniCPM-V
-def run_minicpmv(question):
+def run_minicpmv(question, modality):
+    assert modality == "image"
     # 2.0
     # The official repo doesn't work yet, so we need to use a fork for now
@@ -129,7 +155,9 @@ def run_minicpmv(question):
 # InternVL
-def run_internvl(question):
+def run_internvl(question, modality):
+    assert modality == "image"
     model_name = "OpenGVLab/InternVL2-2B"
     llm = LLM(
@@ -155,7 +183,8 @@ def run_internvl(question):
 # BLIP-2
-def run_blip2(question):
+def run_blip2(question, modality):
+    assert modality == "image"
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
@@ -166,7 +195,8 @@ def run_blip2(question):
 # Qwen
-def run_qwen_vl(question):
+def run_qwen_vl(question, modality):
+    assert modality == "image"
     llm = LLM(
         model="Qwen/Qwen-VL",
@@ -180,7 +210,9 @@ def run_qwen_vl(question):
 # Qwen2-VL
-def run_qwen2_vl(question):
+def run_qwen2_vl(question, modality):
+    assert modality == "image"
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
     llm = LLM(
@@ -200,6 +232,7 @@ model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
     "llava-next-video": run_llava_next_video,
+    "llava-onevision": run_llava_onevision,
     "fuyu": run_fuyu,
     "phi3_v": run_phi3v,
     "paligemma": run_paligemma,
@@ -255,7 +288,7 @@ def main(args):
     data = mm_input["data"]
     question = mm_input["question"]
-    llm, prompt, stop_token_ids = model_example_map[model](question)
+    llm, prompt, stop_token_ids = model_example_map[model](question, modality)
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
@@ -306,6 +339,7 @@ if __name__ == "__main__":
     parser.add_argument('--modality',
                         type=str,
                         default="image",
+                        choices=['image', 'video'],
                        help='Modality of the input.')
     parser.add_argument('--num-frames',
                         type=int,
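For orientation, here is a minimal sketch of how the new run_llava_onevision entry is exercised for video input, mirroring what main() does with the returned (llm, prompt, stop_token_ids) triple. It assumes vLLM's standard multi-modal prompt dict; the dummy frame array, token budget, and decode step are illustrative and not part of this commit.

# Illustrative usage sketch (editor's addition, not part of the commit).
import numpy as np

from vllm import SamplingParams

llm, prompt, stop_token_ids = run_llava_onevision("Why is this video funny?",
                                                   "video")
# A real run loads frames from a video asset; a dummy clip of 16 RGB frames
# is used here only to show the expected (frames, height, width, 3) layout.
video = np.zeros((16, 384, 384, 3), dtype=np.uint8)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"video": video},
    },
    sampling_params=SamplingParams(temperature=0.2,
                                   max_tokens=64,
                                   stop_token_ids=stop_token_ids))
print(outputs[0].outputs[0].text)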

View File

@@ -105,9 +105,6 @@ def run_test(
         for asset in video_assets
     ]
-    for video in videos:
-        print(video.shape)
     if size_factors is not None:
         inputs_per_video = [(
             [prompt for _ in size_factors],

View File

@@ -0,0 +1,356 @@
from typing import List, Optional, Tuple, Type, overload
import pytest
import transformers
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import (rescale_image_size, rescale_video_size,
resize_video, sample_frames_from_video)
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
from ...utils import check_logprobs_close
# Video test
HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"sample_demo_1":
"<|im_start|>user <video>\nwhy is this video funny? \
<|im_end|><|im_start|>assistant\n"
})
models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
video_token_id = config.video_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != video_token_id or output_ids[idx - 1] != video_token_id
]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
videos = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]
if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
elif sizes is not None:
inputs_per_video = [(
[prompt for _ in sizes],
[resize_video(video, size) for size in sizes],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_video = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values_videos"] = hf_inputs["pixel_values_videos"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_video = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
vllm_outputs_per_video):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No video
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
dtype, max_tokens, num_logprobs, num_frames) -> None:
"""Inference result should be the same between hf and vllm.
All the video fixtures for the test are under tests/videos.
For huggingface runner, we provide the np.ndarray as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
dtype, max_tokens, num_logprobs,
num_frames) -> None:
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
# Image test
_LIMIT_IMAGE_PER_PROMPT = 4
def run_image_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=32768,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
def process(hf_inputs: BatchEncoding):
hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.skipif(transformers.__version__ < "4.45",
reason="Waiting for next transformers release")
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
inputs = [(
[
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image>\nDescribe 2 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image><image><image><image>\nDescribe 4 images. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <image>\nWhat is the season? \
<|im_end|><|im_start|>assistant\n",
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]
run_image_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
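The multi-image test above translates to the public API roughly as follows. This is a sketch under stated assumptions (checkpoint choice, image paths, and sampling settings are illustrative, and limit_mm_per_prompt mirrors the test's engine argument); it is not part of the commit.

# Illustrative only: offline multi-image inference with the new model.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
          max_model_len=32768,
          limit_mm_per_prompt={"image": 2})
prompt = ("<|im_start|>user <image><image>\nDescribe 2 images."
          "<|im_end|><|im_start|>assistant\n")
# Placeholder paths; any two RGB images work.
images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    sampling_params=SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)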

View File

@@ -6,7 +6,8 @@ from vllm.model_executor.models import _MODELS, ModelRegistry
 @pytest.mark.parametrize("model_cls", _MODELS)
 def test_registry_imports(model_cls):
-    if (model_cls == "Qwen2VLForConditionalGeneration"
+    if (model_cls in ("LlavaOnevisionForConditionalGeneration",
+                      "Qwen2VLForConditionalGeneration")
             and transformers.__version__ < "4.45"):
         pytest.skip("Waiting for next transformers release")

View File

@@ -79,7 +79,7 @@ class VideoAsset:
         return ret
     @property
-    def np_ndarrays(self) -> List[npt.NDArray]:
+    def np_ndarrays(self) -> npt.NDArray:
         video_path = download_video_asset(self.name)
         ret = video_to_ndarrays(video_path, self.num_frames)
         return ret
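A brief illustration of how the changed return type is consumed by the new tests; the asset instance is assumed to come from the existing video_assets fixture, so this is a sketch rather than something runnable in isolation.

# np_ndarrays now yields one stacked array of shape (num_frames, height,
# width, 3) instead of a list of arrays, so frame subsampling can operate
# on it directly, as the LLaVA-OneVision tests do:
from vllm.multimodal.utils import sample_frames_from_video

frames = sample_frames_from_video(asset.np_ndarrays, 16)  # asset: VideoAsset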

View File

@@ -83,12 +83,14 @@ _MULTIMODAL_MODELS = {
     ("chameleon", "ChameleonForConditionalGeneration"),
     "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
-    "LlavaForConditionalGeneration":
-    ("llava", "LlavaForConditionalGeneration"),
+    "LlavaForConditionalGeneration": ("llava",
+                                      "LlavaForConditionalGeneration"),
     "LlavaNextForConditionalGeneration": ("llava_next",
                                           "LlavaNextForConditionalGeneration"),
     "LlavaNextVideoForConditionalGeneration":
     ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+    "LlavaOnevisionForConditionalGeneration":
+    ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
     "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "PaliGemmaForConditionalGeneration": ("paligemma",
                                           "PaliGemmaForConditionalGeneration"),

View File

@@ -2,6 +2,7 @@
 within a vision language model."""
 from typing import Iterable, List, Optional, Tuple, Union
+import numpy as np
 import torch
 import torch.nn as nn
 from PIL import Image
@@ -84,6 +85,24 @@ def dummy_image_for_clip(
     return {"image": image if num_images == 1 else [image] * num_images}
+def dummy_video_for_clip(
+    hf_config: CLIPVisionConfig,
+    num_frames: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    pil_frame = dummy_image_for_clip(
+        hf_config,
+        num_images=1,
+        image_width_override=image_width_override,
+        image_height_override=image_height_override)
+    np_frame = np.array(pil_frame["image"])
+    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
+    mm_data = {"video": mm_data_per_video}
+    return mm_data
 def input_processor_for_clip(
     model_config: ModelConfig,
     hf_config: CLIPVisionConfig,
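For reference, a small sketch of what the new dummy_video_for_clip helper produces, assuming the dummy image defaults to the config resolution in RGB mode; illustrative only, not part of the commit.

from transformers import CLIPVisionConfig

# dummy_video_for_clip repeats a single dummy frame num_frames times.
mm_data = dummy_video_for_clip(CLIPVisionConfig(image_size=336), num_frames=16)
assert mm_data["video"].shape == (16, 336, 336, 3)  # (frames, H, W, channels)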

View File

@@ -0,0 +1,876 @@
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
SiglipVisionConfig)
from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
logger = init_logger(__name__)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_NUM_VIDEOS = 1
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
Note that only one video input per batch is currently supported.
"""
class LlavaOnevisionImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
image_sizes: NotRequired[torch.Tensor]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class LlavaOnevisionImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
LlavaOnevisionImageEmbeddingInputs]
LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(width * (current_height / height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
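# NOTE (editor's illustration, not part of the committed file): a worked trace
# of the helper above. With height=600, width=800, patches=27 and an anyres
# grid of scale_height=2, scale_width=3: current_height = 54, current_width = 81.
# The original aspect ratio (800 / 600 ~= 1.33) is smaller than the grid's (1.5),
# so the width is unpadded: new_width = int(800 * 54 / 600) = 72, padding = 4,
# current_width = 81 - 8 = 73. The result is (unpadded_features, newline_features)
# = (54 * 73, 54) = (3942, 54); the ratio check sqrt(3942 / (9 * 27**2)) ~= 0.78
# stays below 1.1, so no further downscaling is applied.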
def get_llava_onevision_image_feature_size(
hf_config: LlavaOnevisionConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size(
input_height, input_width, num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
def get_max_llava_onevision_image_tokens(ctx: InputContext):
return get_llava_onevision_image_feature_size(
ctx.get_hf_config(LlavaOnevisionConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_llava_onevision_video_frame_feature_size(
hf_config: LlavaOnevisionConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
hf_config, "spatial_pool_stride") else 2
height = width = image_size // patch_size
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
return video_feature_size
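# NOTE (editor's illustration, not part of the committed file): for a
# SigLIP-style tower with image_size=384 and patch_size=14 (believed to match
# the released LLaVA-OneVision checkpoints), height = width = 384 // 14 = 27,
# and the default spatial_pool_stride of 2 gives ceil(27 / 2) ** 2 = 196 tokens
# per frame. With the profiling default of 16 frames this yields
# 16 * 196 + 1 = 3137 video tokens, the extra token being the trailing
# image_newline.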
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos > _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
)
mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_video(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
elif is_list_of(video_data, np.ndarray):
raise NotImplementedError(
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_for_llava_onevision(ctx: InputContext,
llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data):
return llm_inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, llm_inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, llm_inputs)
msg = "Unsupported multi data type"
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
super().__init__()
self.linear_1 = nn.Linear(config.vision_config.hidden_size,
config.text_config.hidden_size,
bias=True)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = nn.Linear(config.text_config.hidden_size,
config.text_config.hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaOnevisionConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaOnevisionImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaOnevisionImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self,
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
"""
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values = kwargs.pop("pixel_values_videos", None)
if pixel_values is None:
return None
if not (is_list_of(pixel_values,
(torch.Tensor)) # different shape videos
or isinstance(pixel_values,
torch.Tensor)): # same shape videos
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos",
data=pixel_values,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
if "pixel_values" in kwargs:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if "pixel_values_videos" in kwargs:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
return modalities
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self,
image_size: torch.Tensor,
patch_embeddings: torch.Tensor,
*,
image_newline=None,
vision_aspect_ratio="anyres_max_9",
strategy: str) -> torch.Tensor:
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
num_patches = num_patch_height * num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds = other_patch_embeds[:num_patches] \
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
(orig_height, orig_width))
max_num_patches = int(
vision_aspect_ratio.removeprefix("anyres_max_"))
channels, curr_height, curr_width = other_patch_embeds.shape
ratio = math.sqrt(curr_height * curr_width /
(max_num_patches * height**2))
if ratio > 1.1:
other_patch_embeds = other_patch_embeds[None]
other_patch_embeds = nn.functional.interpolate(
other_patch_embeds, [
int(curr_height // ratio),
int(curr_width // ratio)
],
mode="bilinear")[0]
if image_newline is not None:
other_patch_embeds = torch.cat(
(
other_patch_embeds,
image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
.to(other_patch_embeds.device),
),
dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
.flatten(0, 3)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self,
image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
return [
self._merge_image_patch_embeddings(
image_sizes[i],
patch_features_batch,
image_newline=self.image_newline,
strategy="spatial_unpad")
for i, patch_features_batch in enumerate(patch_embeddings)
]
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
b, num_videos, frames, c, h, w = pixel_values.shape
assert (num_videos == _MAX_NUM_VIDEOS)
pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(
b, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to(
video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["data"]
# TODO: support multiple videos per input
if isinstance(video_pixels, torch.Tensor):
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, video_pixels)
return stacked_embeddings
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def apply_pooling(self, image_features, stride=2):
vision_config = self.config.vision_config
height = width = vision_config.image_size // vision_config.patch_size
batch_frames, _, dim = image_features.shape
image_features = image_features.view(batch_frames, height, width, -1)
image_features = image_features.permute(0, 3, 1, 2)
# TODO support other pooling types config
height, width = image_features.shape[2:]
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
image_feature = nn.functional.interpolate(image_features,
size=scaled_shape,
mode='bilinear')
image_feature = image_feature.permute(0, 2, 3, 1)
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for LlaVA-Onevision.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixel values for each frame of each input video.
"""
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
# merge video embeddings into input embeddings
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision encoder
self.vision_tower.load_weights(weights_group["vision_tower"])
# load mlp projector
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
for name, loaded_weight in weights_group["multi_modal_projector"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])

View File

@@ -4,6 +4,7 @@ within a vision language model."""
 import math
 from typing import Iterable, List, Optional, Tuple, Union
+import numpy as np
 import torch
 from PIL import Image
 from torch import nn
@@ -89,6 +90,24 @@ def dummy_image_for_siglip(
     return {"image": image if num_images == 1 else [image] * num_images}
+def dummy_video_for_siglip(
+    hf_config: SiglipVisionConfig,
+    num_frames: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    pil_frame = dummy_image_for_siglip(
+        hf_config,
+        num_images=1,
+        image_width_override=image_width_override,
+        image_height_override=image_height_override)
+    np_frame = np.array(pil_frame["image"])
+    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
+    mm_data = {"video": mm_data_per_video}
+    return mm_data
 def input_processor_for_siglip(
     model_config: ModelConfig,
     hf_config: SiglipVisionConfig,