[Model] Add support for H2OVL-Mississippi models (#9747)

Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent 1f1b6d6eda
commit 54597724f4
@@ -440,6 +440,12 @@ Text Generation
    - :code:`THUDM/glm-4v-9b` etc.
    -
    - ✅︎
  * - :code:`H2OVLChatModel`
    - H2OVL
    - T + I\ :sup:`E+`
    - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
    -
    - ✅︎
  * - :code:`InternVLChatModel`
    - InternVL2
    - T + I\ :sup:`E+`
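Per the legend used elsewhere in this table, ``T + I\ :sup:`E+``` means text plus image inputs, with pre-computed image embeddings accepted and multiple images allowed per prompt. As a minimal sketch (not part of this commit), the model listed above could also be queried through vLLM's OpenAI-compatible server; the serve command, port, prompt, and image URL below are assumptions, not taken from this change:

# Assumes the server was started with something like:
#   vllm serve h2oai/h2ovl-mississippi-2b --trust-remote-code --max-model-len 8192
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="h2oai/h2ovl-mississippi-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/stop_sign.jpg"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)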
@@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str):
    return llm, prompt, stop_token_ids


# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question: str, modality: str):
    assert modality == "image"

@@ -363,6 +388,7 @@ model_example_map = {
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "h2ovl_chat": run_h2ovl,
    "internvl_chat": run_internvl,
    "NVLM_D": run_nvlm_d,
    "qwen_vl": run_qwen_vl,

@@ -475,4 +501,4 @@ if __name__ == "__main__":
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)
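For context, a hedged sketch of how the example script's driver consumes the (llm, prompt, stop_token_ids) triple returned by run_h2ovl above; the sampling values and the image asset chosen here are illustrative assumptions, only the generate-call pattern mirrors the existing example:

# Illustrative driver, not part of this commit.
from vllm import SamplingParams
from vllm.assets.image import ImageAsset

llm, prompt, stop_token_ids = run_h2ovl("What is in this image?", "image")

image = ImageAsset("stop_sign").pil_image  # any PIL image works here
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)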
@@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    )


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

@@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
    "phi3_v": load_phi3v,
    "h2ovl_chat": load_h2onvl,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,
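A short, hedged sketch of running the multi-image request built by load_h2onvl above; the question and URLs are placeholders, and only the generate-call pattern follows the existing multi-image example:

# Illustrative only, not part of this commit.
from vllm import SamplingParams

req = load_h2onvl("Describe the two images in short.",
                  ["https://example.com/a.jpg", "https://example.com/b.jpg"])

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},  # list of PIL images
    },
    sampling_params=SamplingParams(max_tokens=64,
                                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)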
tests/models/decoder_only/vision_language/test_h2ovl.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from typing import Optional, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
                                              image_to_pixel_values_wrapper)
from vllm.multimodal.utils import rescale_image_size

models = [
    "h2oai/h2ovl-mississippi-800m",  # Replace with your actual model names
    "h2oai/h2ovl-mississippi-2b",
]
target_dtype = "bfloat16"


def run_preprocessing_test(
    image: Image,
    config,
    max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
    """Test the image preprocessing and calculate expected blocks."""

    if max_dynamic_patch is None:
        max_dynamic_patch = config.max_dynamic_patch

    width, height = image.size
    use_MSAC = config.use_msac

    # Create the mapper function with the provided configuration
    mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
    pixel_values = mapper(image)

    # Calculate the expected number of blocks
    if use_MSAC:
        # First pass
        blocks1, _, _, aspect_ratio = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,  # Thumbnail is handled separately
            prior_aspect_ratio=None,
        )

        # Second pass
        blocks2, _, _, _ = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,
            prior_aspect_ratio=aspect_ratio,
        )

        # Add thumbnail if use_thumbnail is True and total_blocks > 1
        if config.use_thumbnail:
            blocks1 += 1 if blocks1 > 1 else 0
            blocks2 += 1 if blocks2 > 1 else 0

        # Total blocks is the sum of blocks from both passes minus overlapping
        total_blocks = blocks1 + blocks2 - 1

        expected_blocks = total_blocks

    else:
        blocks, _, _, _ = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,
            prior_aspect_ratio=None,
        )
        expected_blocks = blocks

        if config.use_thumbnail and expected_blocks > 1:
            expected_blocks += 1

    return pixel_values, expected_blocks


@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
                             max_dynamic_patch):
    """Test image preprocessing pipeline with different configurations."""
    # Load the configuration from the model
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    for asset in image_assets:
        image = asset.pil_image
        for factor in size_factors:
            scaled_image = rescale_image_size(image, factor)

            # Test preprocessing and get expected number of blocks
            pixel_values, expected_blocks = run_preprocessing_test(
                scaled_image, config, max_dynamic_patch)

            # Verify output shapes and properties
            actual_blocks = pixel_values.shape[0]
            assert actual_blocks == expected_blocks, (
                f"Expected {expected_blocks} blocks, got {actual_blocks}")

            # Check image dimensions
            expected_size = (
                3,  # Number of channels (C, H, W)
                config.vision_config.image_size,
                config.vision_config.image_size,
            )
            for img in pixel_values:
                assert img.shape == expected_size, (
                    f"Expected image size {expected_size}, got {img.shape}")
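# Worked example of the expected-block arithmetic above (hypothetical numbers,
# not taken from this commit): suppose the first pass tiles an image into
# blocks1 = 6 crops and the second pass, constrained by the prior aspect
# ratio, yields blocks2 = 4 crops. With use_thumbnail=True each pass gains one
# thumbnail (7 and 5), and the shared thumbnail is counted only once:
#   total_blocks = 7 + 5 - 1 = 11
# so pixel_values.shape[0] is expected to be 11 for that image.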
@@ -187,6 +187,23 @@ VLM_TEST_SETTINGS = {
        marks=[large_gpu_mark(min_gb=48)],
        patch_hf_runner=model_utils.glm_patch_hf_runner,
    ),
    "h2ovl": VLMTestInfo(
        models=[
            "h2oai/h2ovl-mississippi-800m",
            "h2oai/h2ovl-mississippi-2b",
        ],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>",  # noqa: E501
        single_image_prompts=IMAGE_ASSETS.prompts({
            "stop_sign": "<image>\nWhat's the content in the center of the image?",  # noqa: E501
            "cherry_blossom": "<image>\nWhat is the season?",
        }),
        multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.",  # noqa: E501
        max_model_len=8192,
        dtype="bfloat16",
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
    ),
    "intern_vl": VLMTestInfo(
        models=[
            "OpenGVLab/InternVL2-1B",
@@ -259,6 +259,66 @@ def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for H2OVL."""

    class H2OVLProcessor:
        """A simple processor for H2OVL models."""

        def __init__(self, hf_runner: HfRunner):
            self.num_image_token = hf_runner.model.num_image_token
            self.tokenizer = hf_runner.tokenizer
            self.dtype = hf_runner.model.dtype

            self.config = AutoConfig.from_pretrained(hf_runner.model_name,
                                                     trust_remote_code=True)
            self.vision_config = self.config.vision_config
            self.use_thumbnail = self.config.use_thumbnail
            self.min_num = self.config.min_dynamic_patch
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Union[Image, List[Image]],
                     **kwargs):
            # yapf: disable
            from vllm.model_executor.models.h2ovl import (
                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)

            # yapf: enable
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values(image,
                                      self.image_size,
                                      self.min_num,
                                      self.max_num,
                                      self.use_thumbnail,
                                      use_MSAC=self.config.use_msac).to(
                                          self.dtype) for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            pixel_values = torch.cat(pixel_values, dim=0)
            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
                image_tokens = IMG_START + context_tokens + IMG_END
                text = text.replace('<image>', image_tokens, 1)
            prompt = self.tokenizer(text, return_tensors="pt")
            prompt.update({"pixel_values": pixel_values})
            return prompt

    img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
        "<IMG_CONTEXT>")
    hf_model.model.img_context_token_id = img_context_token_id
    hf_model.processor = H2OVLProcessor(hf_model)
    hf_model.model.get_output_embeddings = lambda: \
        hf_model.model.language_model.get_output_embeddings()
    hf_model.model.generate = types.MethodType(_internvl_generate,
                                               hf_model.model)
    return hf_model


def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for InternVL."""
@@ -187,7 +187,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        if model_type.startswith("llava"):
            return self._cached_token_str(self._tokenizer,
                                          hf_config.image_token_index)
        if model_type in ("chameleon", "internvl_chat", "NVLM_D",
                          "h2ovl_chat"):
            return "<image>"
        if model_type == "mllama":
            return "<|image|>"
vllm/model_executor/models/h2ovl.py (new file, 401 lines)
@@ -0,0 +1,401 @@
# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import List, Optional, Tuple

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.utils import is_list_of

from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
                       InternVLInputPipeline, build_transform,
                       find_closest_aspect_ratio, get_internvl_num_patches)


# modified to include blocks generated in second pass
def calculate_num_blocks(
    orig_width: int,
    orig_height: int,
    min_num: int,
    max_num: int,
    image_size: int,
    use_thumbnail: bool,
    prior_aspect_ratio=None,
) -> Tuple[int, int, int, Tuple[int, int]]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # if prior_aspect_ratio is provided, filter the target ratios
    if prior_aspect_ratio is not None:
        target_ratios = [
            ratio for ratio in target_ratios if prior_aspect_ratio[0] %
            ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
        ]

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1
    return blocks, target_width, target_height, target_aspect_ratio


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio as optional
def dynamic_preprocess(
    image: Image.Image,
    min_num: int,
    max_num: int,
    image_size: int,
    use_thumbnail: bool,
    prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[List[Image.Image], Tuple[int, int]]:
    orig_width, orig_height = image.size

    # calculate the number of blocks based on prior aspect ratio if available
    blocks, target_width, target_height, target_aspect_ratio = (
        calculate_num_blocks(
            orig_width,
            orig_height,
            min_num,
            max_num,
            image_size,
            use_thumbnail=False,
            prior_aspect_ratio=prior_aspect_ratio,
        ))
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio


def load_image(
    image: Image.Image,
    input_size=448,
    min_num=1,
    max_num=6,
    use_thumbnail=True,
    prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        min_num=min_num,
        max_num=max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, target_aspect_ratio


# refactored to use the combined load_image function
def image_to_pixel_values(
    image: Image.Image,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_MSAC: bool,
) -> torch.Tensor:
    # when MSAC is turned on, we need to process the image twice
    if use_MSAC:
        # first pass
        pixel_values, target_aspect_ratio = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=True,
        )
        # second pass
        pixel_values2, _ = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            prior_aspect_ratio=target_aspect_ratio,
        )
        # combine pixel values
        pixel_values = torch.cat(
            [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)

    else:
        pixel_values, _ = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
        )

    return pixel_values


def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None,
                                  use_MSAC: Optional[bool] = None):
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    if use_MSAC is None:
        use_MSAC = hf_config.use_msac
    use_thumbnail = hf_config.use_thumbnail
    return partial(
        image_to_pixel_values,
        input_size=image_size,
        min_num=min_num,
        max_num=max_dynamic_patch,
        use_thumbnail=use_thumbnail,
        use_MSAC=use_MSAC,
    )


def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    """
    Calculate the maximum number of tokens with/without MSAC and thumbnail
    """
    hf_config = ctx.get_hf_config()
    use_thumbnail = hf_config.use_thumbnail
    use_MSAC = hf_config.use_msac

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    num_patches = get_internvl_num_patches(hf_config)

    coefficient = 2 if use_MSAC else 1
    num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)

    return num_blocks * num_patches
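# Worked example for the formula above (the config values are assumptions
# based on a typical InternVL-style vision config, not read from this commit):
# with a 448x448 vision input, 14x14 patches and a 0.5 downsample ratio,
# get_internvl_num_patches() gives (448 // 14) ** 2 * 0.5 ** 2 = 256 tokens
# per block. With use_msac=True, use_thumbnail=True and max_dynamic_patch=6:
#   num_blocks = 2 * 6 + 1 = 13
#   max image tokens = 13 * 256 = 3328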
class H2OVLInputPipeline(InternVLInputPipeline):
    """
    Input pipeline for processing image and text data for the H2OVL model.
    """

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> DecoderOnlyInputs:
        # get multi_modal_data
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()
        use_MSAC = hf_config.use_msac

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch=max_dynamic_patch)

        # single image
        if isinstance(image_data, Image.Image):
            pixel_values = image_pixel_values_mapper(image_data,
                                                     use_MSAC=use_MSAC)
            num_blocks = pixel_values.shape[0]
            image_feature_sizes = [num_blocks * num_patches]
            pixel_values = pixel_values.unsqueeze(0)

        # multi images
        elif is_list_of(image_data, Image.Image):
            # Do not use MSAC for multi images
            image_feature_sizes = []
            pixel_values = [
                image_pixel_values_mapper(image, use_MSAC=False)
                for image in image_data
            ]
            for pixel_value in pixel_values:
                num_blocks = pixel_value.shape[0]
                image_feature_sizes.append(num_blocks * num_patches)

        # image embeddings as input
        elif isinstance(image_data, torch.Tensor):
            _, image_feature_size, _ = image_data.shape
            image_feature_sizes = [image_feature_size]
            pixel_values = None

        # multi-image image embeddings
        elif is_list_of(image_data, torch.Tensor):

            image_feature_sizes = []
            for image_embed in image_data:
                _, image_feature_size, _ = image_embed.shape
                image_feature_sizes.append(image_feature_size)
            pixel_values = None

        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        # Wrap image processing in input_processor to avoid duplication
        image_token_id = tokenizer.encode(
            self.img_context_token,
            add_special_tokens=False,
            return_tensors="pt",
        )[0]

        # Update multi_modal_data to return
        if pixel_values is not None:
            multi_modal_data = {
                "image": {
                    "pixel_values": pixel_values,
                    "image_token_id": image_token_id,
                }
            }
        else:
            multi_modal_data = {"image": {"image_embeds": image_data}}

        return token_inputs(
            prompt=prompt,
            prompt_token_ids=new_prompt_token_ids,
            multi_modal_data=multi_modal_data,
        )

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> MultiModalInputs:

        # NOTE: Preprocessing for the image data is done in the
        # 'input_processor' function during actual inference.
        if isinstance(data, dict):
            return MultiModalInputs(data)

        # The section below is only used with dummy data during
        # memory profiling.
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)

        if isinstance(data, Image.Image):
            pixel_values = image_pixel_values_mapper(data)
            pixel_values = pixel_values.unsqueeze(0)

        elif is_list_of(data, Image.Image):
            hf_config.use_msac = False
            pixel_values = [image_pixel_values_mapper(img) for img in data]

        else:
            return MultiModalInputs({"image_embeds": data})
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
        image_token_id = tokenizer.encode(
            self.img_context_token,
            add_special_tokens=False,
            return_tensors="pt",
        )[0]

        return MultiModalInputs({
            "pixel_values": pixel_values,
            "image_token_id": image_token_id
        })


input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class H2OVLChatModel(InternVLChatModel):

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = (config.vision_config.num_hidden_layers +
                                     vision_feature_layer + 1)
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)
@@ -128,6 +128,7 @@ _MULTIMODAL_MODELS = {
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501

@@ -482,4 +483,4 @@ def _run() -> None:


if __name__ == "__main__":
    _run()
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             EAGLEConfig, ExaoneConfig,
                                             H2OVLChatConfig,
                                             InternVLChatConfig, JAISConfig,
                                             MedusaConfig, MllamaConfig,
                                             MLPSpeculatorConfig, MPTConfig,

@@ -52,6 +53,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "medusa": MedusaConfig,
    "eagle": EAGLEConfig,
    "exaone": ExaoneConfig,
    "h2ovl_chat": H2OVLChatConfig,
    "internvl_chat": InternVLChatConfig,
    "nemotron": NemotronConfig,
    "NVLM_D": NVLM_D_Config,

@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig

@@ -22,6 +23,7 @@ __all__ = [
    "DbrxConfig",
    "MPTConfig",
    "RWConfig",
    "H2OVLChatConfig",
    "InternVLChatConfig",
    "JAISConfig",
    "MedusaConfig",

@@ -33,4 +35,4 @@ __all__ = [
    "NVLM_D_Config",
    "SolarConfig",
    "UltravoxConfig",
]
vllm/transformers_utils/configs/h2ovl.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Adapted from
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------

from .internvl import InternVLChatConfig


class H2OVLChatConfig(InternVLChatConfig):
    model_type = "h2ovl_chat"