[Model] Add support for H2OVL-Mississippi models (#9747)

Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Authored by shanshan wang on 2024-11-03 18:15:36 -06:00, committed by GitHub
parent 1f1b6d6eda, commit 54597724f4
12 changed files with 698 additions and 4 deletions


@@ -440,6 +440,12 @@ Text Generation
    - :code:`THUDM/glm-4v-9b` etc.
    -
    - ✅︎
  * - :code:`H2OVLChatModel`
    - H2OVL
    - T + I\ :sup:`E+`
    - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc.
    -
    - ✅︎
  * - :code:`InternVLChatModel`
    - InternVL2
    - T + I\ :sup:`E+`


@@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str):
    return llm, prompt, stop_token_ids


# H2OVL-Mississippi
def run_h2ovl(question: str, modality: str):
    assert modality == "image"

    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question: str, modality: str):
    assert modality == "image"
@@ -363,6 +388,7 @@ model_example_map = {
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "h2ovl_chat": run_h2ovl,
    "internvl_chat": run_internvl,
    "NVLM_D": run_nvlm_d,
    "qwen_vl": run_qwen_vl,
@@ -475,4 +501,4 @@ if __name__ == "__main__":
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)
    main(args)

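For context, a minimal sketch of how the new run_h2ovl() entry from the hunk above can be driven end to end, in the same style as the example script's main(); the image path and sampling values here are illustrative assumptions, not part of the diff:

from PIL import Image

from vllm import SamplingParams

# Assumes run_h2ovl() defined in the example above is in scope.
llm, prompt, stop_token_ids = run_h2ovl("What is shown in this image?", "image")

image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(temperature=0.2,
                                   max_tokens=64,
                                   stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)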

@@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    )


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"
@@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
    "phi3_v": load_phi3v,
    "h2ovl_chat": load_h2onvl,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,

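As a rough illustration of how the multi-image script consumes the ModelRequestData returned by load_h2onvl() above (the URLs and sampling values are assumed placeholders, not part of the diff):

from vllm import SamplingParams

req = load_h2onvl("Describe the two images briefly.", [
    "https://example.com/a.jpg",  # placeholder URLs
    "https://example.com/b.jpg",
])
outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=SamplingParams(max_tokens=128,
                                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)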

@@ -0,0 +1,130 @@
from typing import Optional, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
                                              image_to_pixel_values_wrapper)
from vllm.multimodal.utils import rescale_image_size

models = [
    "h2oai/h2ovl-mississippi-800m",  # Replace with your actual model names
    "h2oai/h2ovl-mississippi-2b",
]
target_dtype = "bfloat16"


def run_preprocessing_test(
    image: Image,
    config,
    max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
    """Test the image preprocessing and calculate expected blocks."""

    if max_dynamic_patch is None:
        max_dynamic_patch = config.max_dynamic_patch

    width, height = image.size
    use_MSAC = config.use_msac

    # Create the mapper function with the provided configuration
    mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
    pixel_values = mapper(image)

    # Calculate the expected number of blocks
    if use_MSAC:
        # First pass
        blocks1, _, _, aspect_ratio = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,  # Thumbnail is handled separately
            prior_aspect_ratio=None,
        )

        # Second pass
        blocks2, _, _, _ = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,
            prior_aspect_ratio=aspect_ratio,
        )

        # Add thumbnail if use_thumbnail is True and total_blocks > 1
        if config.use_thumbnail:
            blocks1 += 1 if blocks1 > 1 else 0
            blocks2 += 1 if blocks2 > 1 else 0

        # Total blocks is the sum of blocks from both passes minus overlapping
        total_blocks = blocks1 + blocks2 - 1

        expected_blocks = total_blocks

    else:
        blocks, _, _, _ = calculate_num_blocks(
            width,
            height,
            config.min_dynamic_patch,
            max_dynamic_patch,
            config.vision_config.image_size,
            use_thumbnail=False,
            prior_aspect_ratio=None,
        )
        expected_blocks = blocks

        if config.use_thumbnail and expected_blocks > 1:
            expected_blocks += 1

    return pixel_values, expected_blocks


@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
                             max_dynamic_patch):
    """Test image preprocessing pipeline with different configurations."""
    # Load the configuration from the model
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    for asset in image_assets:
        image = asset.pil_image
        for factor in size_factors:
            scaled_image = rescale_image_size(image, factor)

            # Test preprocessing and get expected number of blocks
            pixel_values, expected_blocks = run_preprocessing_test(
                scaled_image, config, max_dynamic_patch)

            # Verify output shapes and properties
            actual_blocks = pixel_values.shape[0]
            assert actual_blocks == expected_blocks, (
                f"Expected {expected_blocks} blocks, got {actual_blocks}")

            # Check image dimensions
            expected_size = (
                3,  # Number of channels (C, H, W)
                config.vision_config.image_size,
                config.vision_config.image_size,
            )
            for img in pixel_values:
                assert img.shape == expected_size, (
                    f"Expected image size {expected_size}, got {img.shape}")

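For intuition, a small worked example of the MSAC block arithmetic the test reproduces; the block counts here are assumed values for illustration, not taken from the diff:

# Suppose the first pass tiles an image into 6 blocks and the second,
# ratio-filtered pass into 2 blocks, with use_thumbnail=True.
blocks1, blocks2 = 6, 2
blocks1 += 1  # thumbnail appended because blocks1 > 1
blocks2 += 1  # thumbnail appended because blocks2 > 1

# image_to_pixel_values() keeps only one thumbnail when it concatenates the
# two passes, hence the "- 1" in the test's total.
total_blocks = blocks1 + blocks2 - 1
assert total_blocks == 9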

@@ -187,6 +187,23 @@ VLM_TEST_SETTINGS = {
        marks=[large_gpu_mark(min_gb=48)],
        patch_hf_runner=model_utils.glm_patch_hf_runner,
    ),
    "h2ovl": VLMTestInfo(
        models=[
            "h2oai/h2ovl-mississippi-800m",
            "h2oai/h2ovl-mississippi-2b",
        ],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>",  # noqa: E501
        single_image_prompts=IMAGE_ASSETS.prompts({
            "stop_sign": "<image>\nWhat's the content in the center of the image?",  # noqa: E501
            "cherry_blossom": "<image>\nWhat is the season?",
        }),
        multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.",  # noqa: E501
        max_model_len=8192,
        dtype="bfloat16",
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
    ),
    "intern_vl": VLMTestInfo(
        models=[
            "OpenGVLab/InternVL2-1B",


@@ -259,6 +259,66 @@ def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for H2OVL."""

    class H2OVLProcessor:
        """A simple processor for H2OVL models."""

        def __init__(self, hf_runner: HfRunner):
            self.num_image_token = hf_runner.model.num_image_token
            self.tokenizer = hf_runner.tokenizer
            self.dtype = hf_runner.model.dtype

            self.config = AutoConfig.from_pretrained(hf_runner.model_name,
                                                     trust_remote_code=True)
            self.vision_config = self.config.vision_config
            self.use_thumbnail = self.config.use_thumbnail
            self.min_num = self.config.min_dynamic_patch
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Union[Image, List[Image]],
                     **kwargs):
            # yapf: disable
            from vllm.model_executor.models.h2ovl import (
                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)

            # yapf: enable
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values(image,
                                      self.image_size,
                                      self.min_num,
                                      self.max_num,
                                      self.use_thumbnail,
                                      use_MSAC=self.config.use_msac).to(
                                          self.dtype) for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            pixel_values = torch.cat(pixel_values, dim=0)

            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
                image_tokens = IMG_START + context_tokens + IMG_END
                text = text.replace('<image>', image_tokens, 1)
            prompt = self.tokenizer(text, return_tensors="pt")
            prompt.update({"pixel_values": pixel_values})
            return prompt

    img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
        "<IMG_CONTEXT>")
    hf_model.model.img_context_token_id = img_context_token_id
    hf_model.processor = H2OVLProcessor(hf_model)
    hf_model.model.get_output_embeddings = lambda: \
        hf_model.model.language_model.get_output_embeddings()
    hf_model.model.generate = types.MethodType(_internvl_generate,
                                               hf_model.model)
    return hf_model


def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for InternVL."""

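To make the placeholder expansion performed by the processor above concrete, here is a hedged sketch with assumed token strings and counts; the real values come from the checkpoint's tokenizer and config:

# Assumed values for illustration only.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"
num_image_token = 256  # feature tokens per 448x448 block (assumed)
num_patches = 5        # blocks produced for one particular image (assumed)

text = "<|prompt|><image>\nWhat's in the image?<|end|><|answer|>"
image_tokens = IMG_START + IMG_CONTEXT * (num_image_token * num_patches) + IMG_END
expanded = text.replace("<image>", image_tokens, 1)
# The HF-side prompt now carries 256 * 5 = 1280 IMG_CONTEXT tokens for this image.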

@@ -187,7 +187,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        if model_type.startswith("llava"):
            return self._cached_token_str(self._tokenizer,
                                          hf_config.image_token_index)
        if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
        if model_type in ("chameleon", "internvl_chat", "NVLM_D",
                          "h2ovl_chat"):
            return "<image>"
        if model_type == "mllama":
            return "<|image|>"

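Because "h2ovl_chat" now maps to the generic "<image>" placeholder, the OpenAI-compatible chat endpoint can accept image_url content parts for these models. A hedged end-to-end sketch; the serve command, port, and image URL are assumptions, not taken from this diff:

# Assumes an OpenAI-compatible vLLM server was started separately, e.g.:
#   vllm serve h2oai/h2ovl-mississippi-2b --trust-remote-code
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="h2oai/h2ovl-mississippi-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            # Placeholder URL; vLLM maps this part to the "<image>" token above.
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)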

@@ -0,0 +1,401 @@
# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import List, Optional, Tuple

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.utils import is_list_of

from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
                       InternVLInputPipeline, build_transform,
                       find_closest_aspect_ratio, get_internvl_num_patches)


# modified to include blocks generated in second pass
def calculate_num_blocks(
    orig_width: int,
    orig_height: int,
    min_num: int,
    max_num: int,
    image_size: int,
    use_thumbnail: bool,
    prior_aspect_ratio=None,
) -> Tuple[int, int, int, Tuple[int, int]]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # if prior_aspect_ratio is provided, filter the target ratios
    if prior_aspect_ratio is not None:
        target_ratios = [
            ratio for ratio in target_ratios if prior_aspect_ratio[0] %
            ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
        ]

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1
    return blocks, target_width, target_height, target_aspect_ratio


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio as optional
def dynamic_preprocess(
    image: Image.Image,
    min_num: int,
    max_num: int,
    image_size: int,
    use_thumbnail: bool,
    prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[List[Image.Image], Tuple[int, int]]:
    orig_width, orig_height = image.size

    # calculate the number of blocks based on prior aspect ratio if available
    blocks, target_width, target_height, target_aspect_ratio = (
        calculate_num_blocks(
            orig_width,
            orig_height,
            min_num,
            max_num,
            image_size,
            use_thumbnail=False,
            prior_aspect_ratio=prior_aspect_ratio,
        ))
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio


def load_image(
    image: Image.Image,
    input_size=448,
    min_num=1,
    max_num=6,
    use_thumbnail=True,
    prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(
        image,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
        min_num=min_num,
        max_num=max_num,
        prior_aspect_ratio=prior_aspect_ratio,
    )
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, target_aspect_ratio


# refactored to use the combined load_image function
def image_to_pixel_values(
    image: Image.Image,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    use_MSAC: bool,
) -> torch.Tensor:
    # when MSAC is turned on, we need to process the image twice
    if use_MSAC:
        # first pass
        pixel_values, target_aspect_ratio = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=True,
        )
        # second pass
        pixel_values2, _ = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            prior_aspect_ratio=target_aspect_ratio,
        )
        # combine pixel values
        pixel_values = torch.cat(
            [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)

    else:
        pixel_values, _ = load_image(
            image,
            input_size=input_size,
            min_num=min_num,
            max_num=max_num,
            use_thumbnail=use_thumbnail,
        )

    return pixel_values


def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None,
                                  use_MSAC: Optional[bool] = None):
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    if use_MSAC is None:
        use_MSAC = hf_config.use_msac
    use_thumbnail = hf_config.use_thumbnail
    return partial(
        image_to_pixel_values,
        input_size=image_size,
        min_num=min_num,
        max_num=max_dynamic_patch,
        use_thumbnail=use_thumbnail,
        use_MSAC=use_MSAC,
    )


def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    """
    Calculate the maximum number of tokens with/without MSAC and thumbnail
    """
    hf_config = ctx.get_hf_config()
    use_thumbnail = hf_config.use_thumbnail
    use_MSAC = hf_config.use_msac

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    num_patches = get_internvl_num_patches(hf_config)

    coefficient = 2 if use_MSAC else 1
    num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)

    return num_blocks * num_patches


class H2OVLInputPipeline(InternVLInputPipeline):
    """
    Input pipeline for processing image and text data for the H2OVL model.
    """

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> DecoderOnlyInputs:
        # get multi_modal_data
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()
        use_MSAC = hf_config.use_msac

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch=max_dynamic_patch)

        # single image
        if isinstance(image_data, Image.Image):
            pixel_values = image_pixel_values_mapper(image_data,
                                                     use_MSAC=use_MSAC)
            num_blocks = pixel_values.shape[0]
            image_feature_sizes = [num_blocks * num_patches]
            pixel_values = pixel_values.unsqueeze(0)

        # multi images
        elif is_list_of(image_data, Image.Image):
            # Do not use MSAC for multi images
            image_feature_sizes = []
            pixel_values = [
                image_pixel_values_mapper(image, use_MSAC=False)
                for image in image_data
            ]
            for pixel_value in pixel_values:
                num_blocks = pixel_value.shape[0]
                image_feature_sizes.append(num_blocks * num_patches)

        # image embeddings as input
        elif isinstance(image_data, torch.Tensor):
            _, image_feature_size, _ = image_data.shape
            image_feature_sizes = [image_feature_size]
            pixel_values = None

        # multi-image image embeddings
        elif is_list_of(image_data, torch.Tensor):
            image_feature_sizes = []
            for image_embed in image_data:
                _, image_feature_size, _ = image_embed.shape
                image_feature_sizes.append(image_feature_size)
            pixel_values = None

        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        # Wrap image processing in input_processor to avoid duplication
        image_token_id = tokenizer.encode(
            self.img_context_token,
            add_special_tokens=False,
            return_tensors="pt",
        )[0]

        # Update multi_modal_data to return
        if pixel_values is not None:
            multi_modal_data = {
                "image": {
                    "pixel_values": pixel_values,
                    "image_token_id": image_token_id,
                }
            }
        else:
            multi_modal_data = {"image": {"image_embeds": image_data}}

        return token_inputs(
            prompt=prompt,
            prompt_token_ids=new_prompt_token_ids,
            multi_modal_data=multi_modal_data,
        )

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> MultiModalInputs:
        # NOTE: Preprocessing for the image data is done in the
        # 'input_processor' function during actual inference.
        if isinstance(data, dict):
            return MultiModalInputs(data)

        # The section below is only used with dummy data during
        # memory profiling.
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)

        if isinstance(data, Image.Image):
            pixel_values = image_pixel_values_mapper(data)
            pixel_values = pixel_values.unsqueeze(0)

        elif is_list_of(data, Image.Image):
            hf_config.use_msac = False
            pixel_values = [image_pixel_values_mapper(img) for img in data]

        else:
            return MultiModalInputs({"image_embeds": data})
        model_config = ctx.model_config

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        image_token_id = tokenizer.encode(
            self.img_context_token,
            add_special_tokens=False,
            return_tensors="pt",
        )[0]

        return MultiModalInputs({
            "pixel_values": pixel_values,
            "image_token_id": image_token_id
        })


input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class H2OVLChatModel(InternVLChatModel):

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = (config.vision_config.num_hidden_layers +
                                     vision_feature_layer + 1)
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            msg = "Monolith mode is not applicable to H2OVL"
            raise NotImplementedError(msg)

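A quick arithmetic sketch of the worst-case image token budget computed by get_max_internvl_image_tokens() above; the per-block token count and config values are assumed figures for illustration, not read from the diff:

use_MSAC, use_thumbnail = True, True
max_dynamic_patch = 6   # assumed config value
num_patches = 256       # assumed feature tokens contributed per block

coefficient = 2 if use_MSAC else 1
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)  # 13
max_image_tokens = num_blocks * num_patches                                 # 3328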

@@ -128,6 +128,7 @@ _MULTIMODAL_MODELS = {
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
@@ -482,4 +483,4 @@ def _run() -> None:
if __name__ == "__main__":
    _run()
    _run()


@@ -19,6 +19,7 @@ from vllm.logger import init_logger
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             EAGLEConfig, ExaoneConfig,
                                             H2OVLChatConfig,
                                             InternVLChatConfig, JAISConfig,
                                             MedusaConfig, MllamaConfig,
                                             MLPSpeculatorConfig, MPTConfig,
@@ -52,6 +53,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "medusa": MedusaConfig,
    "eagle": EAGLEConfig,
    "exaone": ExaoneConfig,
    "h2ovl_chat": H2OVLChatConfig,
    "internvl_chat": InternVLChatConfig,
    "nemotron": NemotronConfig,
    "NVLM_D": NVLM_D_Config,


@@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -22,6 +23,7 @@ __all__ = [
    "DbrxConfig",
    "MPTConfig",
    "RWConfig",
    "H2OVLChatConfig",
    "InternVLChatConfig",
    "JAISConfig",
    "MedusaConfig",
@@ -33,4 +35,4 @@ __all__ = [
    "NVLM_D_Config",
    "SolarConfig",
    "UltravoxConfig",
]
]


@@ -0,0 +1,13 @@
# Adapted from
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from .internvl import InternVLChatConfig


class H2OVLChatConfig(InternVLChatConfig):
    model_type = "h2ovl_chat"
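
A small sanity-check sketch of how the pieces added in this commit fit together: the config registry entry maps the checkpoint's model_type to the new config class, and the architecture registry routes it to H2OVLChatModel. The registry names are taken from the hunks above; running this requires a vLLM build that contains this commit:

# Hedged sketch; _CONFIG_REGISTRY is a module-level table shown in the diff above.
from vllm.transformers_utils.config import _CONFIG_REGISTRY
from vllm.transformers_utils.configs import H2OVLChatConfig

assert _CONFIG_REGISTRY["h2ovl_chat"] is H2OVLChatConfig
assert H2OVLChatConfig.model_type == "h2ovl_chat"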